MakiPan committed
Commit 6d21cff · 1 Parent(s): 3dd4371

Update app.py


added radio i think

Files changed (1)
  1. app.py +75 -37
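
In short, the commit adds a "Model preprocessing" radio to the Blocks layout and appends it to the Gradio input lists, so the selected value ("Standard" or "Hand Encoding") arrives as the new model_type argument of infer(). A minimal, self-contained sketch of that wiring pattern (the stub infer body and the Textbox output below are placeholders for illustration, not the app's actual pipeline code):

import gradio as gr

def infer(prompt, negative_prompt, image, model_type="Standard"):
    # The radio selection is passed positionally, exactly like the textbox and image values.
    return f"would run the {model_type} pipeline for prompt: {prompt}"

with gr.Blocks() as demo:
    model_type = gr.Radio(["Standard", "Hand Encoding"], label="Model preprocessing")
    prompt_input = gr.Textbox(label="Prompt")
    negative_prompt = gr.Textbox(label="Negative Prompt")
    input_image = gr.Image(label="Input Image")
    output = gr.Textbox(label="Output")
    submit_btn = gr.Button("Submit")
    # Appending model_type to inputs is what routes the radio value into infer.
    submit_btn.click(
        fn=infer,
        inputs=[prompt_input, negative_prompt, input_image, model_type],
        outputs=[output],
    )

demo.launch()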
app.py CHANGED
@@ -86,10 +86,9 @@ def generate_annotation(img, overlap=False, hand_encoding=False):
     annotated_image = draw_landmarks_on_image(image.numpy_view(), detection_result, overlap=overlap, hand_encoding=hand_encoding)
     return annotated_image
 
-model_type = gr.Radio(["Standard", "Hand Encoding"], label="Model preprocessing", info="We developed two models, one with standard mediapipe landmarks, and one with different (but similar) coloring on palm landmards to distinguish left and right")
-model_type="Standard"
-if model_type=="Standard":
-    args = Namespace(
+
+
+std_args = Namespace(
     pretrained_model_name_or_path="runwayml/stable-diffusion-v1-5",
     revision="non-ema",
     from_pt=True,
@@ -97,8 +96,7 @@ if model_type=="Standard":
     controlnet_revision=None,
     controlnet_from_pt=False,
 )
-if model_type=="Hand Encoding":
-    args = Namespace(
+enc_args = Namespace(
     pretrained_model_name_or_path="runwayml/stable-diffusion-v1-5",
     revision="non-ema",
     from_pt=True,
@@ -107,35 +105,58 @@ if model_type=="Hand Encoding":
     controlnet_from_pt=False,
 )
 
-controlnet, controlnet_params = FlaxControlNetModel.from_pretrained(
-    args.controlnet_model_name_or_path,
-    revision=args.controlnet_revision,
-    from_pt=args.controlnet_from_pt,
+std_controlnet, std_controlnet_params = FlaxControlNetModel.from_pretrained(
+    std_args.controlnet_model_name_or_path,
+    revision=std_args.controlnet_revision,
+    from_pt=std_args.controlnet_from_pt,
+    dtype=jnp.float32, # jnp.bfloat16
+)
+enc_controlnet, enc_controlnet_params = FlaxControlNetModel.from_pretrained(
+    enc_args.controlnet_model_name_or_path,
+    revision=enc_args.controlnet_revision,
+    from_pt=enc_args.controlnet_from_pt,
     dtype=jnp.float32, # jnp.bfloat16
 )
 
-pipeline, pipeline_params = FlaxStableDiffusionControlNetPipeline.from_pretrained(
-    args.pretrained_model_name_or_path,
+
+
+std_pipeline, std_pipeline_params = FlaxStableDiffusionControlNetPipeline.from_pretrained(
+    std_args.pretrained_model_name_or_path,
     # tokenizer=tokenizer,
-    controlnet=controlnet,
+    controlnet=std_controlnet,
     safety_checker=None,
     dtype=jnp.float32, # jnp.bfloat16
-    revision=args.revision,
-    from_pt=args.from_pt,
+    revision=std_args.revision,
+    from_pt=std_args.from_pt,
 )
-
+enc_pipeline, enc_pipeline_params = FlaxStableDiffusionControlNetPipeline.from_pretrained(
+    enc_args.pretrained_model_name_or_path,
+    # tokenizer=tokenizer,
+    controlnet=enc_controlnet,
+    safety_checker=None,
+    dtype=jnp.float32, # jnp.bfloat16
+    revision=enc_args.revision,
+    from_pt=enc_args.from_pt,
+)
+
 
-pipeline_params["controlnet"] = controlnet_params
-pipeline_params = jax_utils.replicate(pipeline_params)
+std_pipeline_params["controlnet"] = std_controlnet_params
+std_pipeline_params = jax_utils.replicate(std_pipeline_params)
+
+enc_pipeline_params["controlnet"] = enc_controlnet_params
+enc_pipeline_params = jax_utils.replicate(enc_pipeline_params)
 
 rng = jax.random.PRNGKey(0)
 num_samples = jax.device_count()
 prng_seed = jax.random.split(rng, jax.device_count())
 
 
-def infer(prompt, negative_prompt, image):
+def infer(prompt, negative_prompt, image, model_type="Standard"):
     prompts = num_samples * [prompt]
-    prompt_ids = pipeline.prepare_text_inputs(prompts)
+    if model_type=="Standard":
+        prompt_ids = std_pipeline.prepare_text_inputs(prompts)
+    if model_type=="Hand Encoding":
+        prompt_ids = enc_pipeline.prepare_text_inputs(prompts)
     prompt_ids = shard(prompt_ids)
 
     if model_type=="Standard":
@@ -145,21 +166,39 @@ def infer(prompt, negative_prompt, image):
         annotated_image = generate_annotation(image, overlap=False, hand_encoding=True)
         overlap_image = generate_annotation(image, overlap=True, hand_encoding=True)
         validation_image = Image.fromarray(annotated_image).convert("RGB")
-    processed_image = pipeline.prepare_image_inputs(num_samples * [validation_image])
-    processed_image = shard(processed_image)
 
-    negative_prompt_ids = pipeline.prepare_text_inputs([negative_prompt] * num_samples)
-    negative_prompt_ids = shard(negative_prompt_ids)
+    if model_type=="Standard":
+        processed_image = std_pipeline.prepare_image_inputs(num_samples * [validation_image])
+        processed_image = shard(processed_image)
+
+        negative_prompt_ids = std_pipeline.prepare_text_inputs([negative_prompt] * num_samples)
+        negative_prompt_ids = shard(negative_prompt_ids)
+
+        images = std_pipeline(
+            prompt_ids=prompt_ids,
+            image=processed_image,
+            params=std_pipeline_params,
+            prng_seed=prng_seed,
+            num_inference_steps=50,
+            neg_prompt_ids=negative_prompt_ids,
+            jit=True,
+        ).images
+    if model_type=="Hand Encoding":
+        processed_image = enc_pipeline.prepare_image_inputs(num_samples * [validation_image])
+        processed_image = shard(processed_image)
 
-    images = pipeline(
-        prompt_ids=prompt_ids,
-        image=processed_image,
-        params=pipeline_params,
-        prng_seed=prng_seed,
-        num_inference_steps=50,
-        neg_prompt_ids=negative_prompt_ids,
-        jit=True,
-    ).images
+        negative_prompt_ids = enc_pipeline.prepare_text_inputs([negative_prompt] * num_samples)
+        negative_prompt_ids = shard(negative_prompt_ids)
+
+        images = enc_pipeline(
+            prompt_ids=prompt_ids,
+            image=processed_image,
+            params=enc_pipeline_params,
+            prng_seed=prng_seed,
+            num_inference_steps=50,
+            neg_prompt_ids=negative_prompt_ids,
+            jit=True,
+        ).images
 
 
     images = images.reshape((images.shape[0] * images.shape[1],) + images.shape[-3:])
@@ -176,16 +215,15 @@ with gr.Blocks(theme='gradio/soft') as demo:
     Model1 can be found at [https://huggingface.co/Vincent-luo/controlnet-hands](https://huggingface.co/Vincent-luo/controlnet-hands)
 
     Model2 can be found at [https://huggingface.co/MakiPan/controlnet-encoded-hands-130k/ ](https://huggingface.co/MakiPan/controlnet-encoded-hands-130k/)
-
     Dataset1 can be found at [https://huggingface.co/datasets/MakiPan/hagrid250k-blip2](https://huggingface.co/datasets/MakiPan/hagrid250k-blip2)
 
     Dataset2 can be found at [https://huggingface.co/datasets/MakiPan/hagrid-hand-enc-250k](https://huggingface.co/datasets/MakiPan/hagrid-hand-enc-250k)
 
     Preprocessing1 can be found at [https://github.com/Maki-DS/Jax-Controlnet-hand-training/blob/main/normal-preprocessing.py](https://github.com/Maki-DS/Jax-Controlnet-hand-training/blob/main/normal-preprocessing.py)
-
     Preprocessing2 can be found at [https://github.com/Maki-DS/Jax-Controlnet-hand-training/blob/main/Hand-encoded-preprocessing.py](https://github.com/Maki-DS/Jax-Controlnet-hand-training/blob/main/Hand-encoded-preprocessing.py)
     """)
-
+    model_type = gr.Radio(["Standard", "Hand Encoding"], label="Model preprocessing", info="We developed two models, one with standard mediapipe landmarks, and one with different (but similar) coloring on palm landmards to distinguish left and right")
+
     with gr.Row():
         with gr.Column():
             prompt_input = gr.Textbox(label="Prompt")
@@ -227,13 +265,13 @@ with gr.Blocks(theme='gradio/soft') as demo:
                 "example4.png"
             ],
         ],
-        inputs=[prompt_input, negative_prompt, input_image],
+        inputs=[prompt_input, negative_prompt, input_image, model_type],
        outputs=[output_image],
        fn=infer,
        cache_examples=True,
    )
 
-    inputs = [prompt_input, negative_prompt, input_image]
+    inputs = [prompt_input, negative_prompt, input_image, model_type]
    submit_btn.click(fn=infer, inputs=inputs, outputs=[output_image])
 
 demo.launch()
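
Note on the duplicated branches: infer() now repeats the prepare/shard/generate sequence once per model. A possible follow-up (not part of this commit; the PIPELINES dict and the simplified body below are illustrative, and they reuse globals already defined in app.py such as std_pipeline, enc_pipeline, their params, generate_annotation, shard, num_samples, prng_seed and PIL's Image) would be to look the pipeline up by the radio value:

# Hypothetical refactor sketch, not committed code: key the (pipeline, params)
# pairs by the radio value so infer() has a single code path.
PIPELINES = {
    "Standard": (std_pipeline, std_pipeline_params),
    "Hand Encoding": (enc_pipeline, enc_pipeline_params),
}

def infer(prompt, negative_prompt, image, model_type="Standard"):
    pipeline, params = PIPELINES[model_type]
    hand_encoding = model_type == "Hand Encoding"

    # Same annotation step as the committed code, with the flag derived from the selection.
    annotated_image = generate_annotation(image, overlap=False, hand_encoding=hand_encoding)
    validation_image = Image.fromarray(annotated_image).convert("RGB")

    # Tokenize prompts and preprocess the conditioning image, then shard across devices.
    prompt_ids = shard(pipeline.prepare_text_inputs(num_samples * [prompt]))
    negative_prompt_ids = shard(pipeline.prepare_text_inputs(num_samples * [negative_prompt]))
    processed_image = shard(pipeline.prepare_image_inputs(num_samples * [validation_image]))

    images = pipeline(
        prompt_ids=prompt_ids,
        image=processed_image,
        params=params,
        prng_seed=prng_seed,
        num_inference_steps=50,
        neg_prompt_ids=negative_prompt_ids,
        jit=True,
    ).images
    return images.reshape((images.shape[0] * images.shape[1],) + images.shape[-3:])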