Jayabalambika committed on
Commit
eafe433
·
1 Parent(s): 6d21cff

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +10 -4
app.py CHANGED
@@ -155,16 +155,21 @@ def infer(prompt, negative_prompt, image, model_type="Standard"):
155
  prompts = num_samples * [prompt]
156
  if model_type=="Standard":
157
  prompt_ids = std_pipeline.prepare_text_inputs(prompts)
158
- if model_type=="Hand Encoding":
159
  prompt_ids = enc_pipeline.prepare_text_inputs(prompts)
 
 
160
  prompt_ids = shard(prompt_ids)
161
 
162
  if model_type=="Standard":
163
  annotated_image = generate_annotation(image, overlap=False, hand_encoding=False)
164
  overlap_image = generate_annotation(image, overlap=True, hand_encoding=False)
165
- if model_type=="Hand Encoding":
166
  annotated_image = generate_annotation(image, overlap=False, hand_encoding=True)
167
  overlap_image = generate_annotation(image, overlap=True, hand_encoding=True)
 
 
 
168
  validation_image = Image.fromarray(annotated_image).convert("RGB")
169
 
170
  if model_type=="Standard":
@@ -183,7 +188,7 @@ def infer(prompt, negative_prompt, image, model_type="Standard"):
183
  neg_prompt_ids=negative_prompt_ids,
184
  jit=True,
185
  ).images
186
- if model_type=="Hand Encoding":
187
  processed_image = enc_pipeline.prepare_image_inputs(num_samples * [validation_image])
188
  processed_image = shard(processed_image)
189
 
@@ -200,7 +205,8 @@ def infer(prompt, negative_prompt, image, model_type="Standard"):
200
  jit=True,
201
  ).images
202
 
203
-
 
204
  images = images.reshape((images.shape[0] * images.shape[1],) + images.shape[-3:])
205
 
206
  results = [i for i in images]
 
155
  prompts = num_samples * [prompt]
156
  if model_type=="Standard":
157
  prompt_ids = std_pipeline.prepare_text_inputs(prompts)
158
+ elif model_type=="Hand Encoding":
159
  prompt_ids = enc_pipeline.prepare_text_inputs(prompts)
160
+ else:
161
+ pass
162
  prompt_ids = shard(prompt_ids)
163
 
164
  if model_type=="Standard":
165
  annotated_image = generate_annotation(image, overlap=False, hand_encoding=False)
166
  overlap_image = generate_annotation(image, overlap=True, hand_encoding=False)
167
+ elif model_type=="Hand Encoding":
168
  annotated_image = generate_annotation(image, overlap=False, hand_encoding=True)
169
  overlap_image = generate_annotation(image, overlap=True, hand_encoding=True)
170
+
171
+ else:
172
+ pass
173
  validation_image = Image.fromarray(annotated_image).convert("RGB")
174
 
175
  if model_type=="Standard":
 
188
  neg_prompt_ids=negative_prompt_ids,
189
  jit=True,
190
  ).images
191
+ elif model_type=="Hand Encoding":
192
  processed_image = enc_pipeline.prepare_image_inputs(num_samples * [validation_image])
193
  processed_image = shard(processed_image)
194
 
 
205
  jit=True,
206
  ).images
207
 
208
+ else:
209
+ pass
210
  images = images.reshape((images.shape[0] * images.shape[1],) + images.shape[-3:])
211
 
212
  results = [i for i in images]