MakiPan commited on
Commit
58ac711
·
1 Parent(s): 6ab4143

Update app.py

Browse files

added with gr.row for radio and added conditionals depending on model selection

Files changed (1) hide show
  1. app.py +28 -13
app.py CHANGED
@@ -85,17 +85,27 @@ def generate_annotation(img, overlap=False, hand_encoding=False):
85
  # STEP 5: Process the classification result. In this case, visualize it.
86
  annotated_image = draw_landmarks_on_image(image.numpy_view(), detection_result, overlap=overlap, hand_encoding=hand_encoding)
87
  return annotated_image
88
-
89
- model_type = gr.Radio(["Standard", "Hand Encoding"], label="Model preprocessing", info="We developed two models, one with standard mediapipe landmarks, and one with different (but similar) coloring on palm landmards to distinguish left and right")
90
-
91
- args = Namespace(
92
- pretrained_model_name_or_path="runwayml/stable-diffusion-v1-5",
93
- revision="non-ema",
94
- from_pt=True,
95
- controlnet_model_name_or_path="Vincent-luo/controlnet-hands",
96
- controlnet_revision=None,
97
- controlnet_from_pt=False,
98
- )
 
 
 
 
 
 
 
 
 
 
99
 
100
  controlnet, controlnet_params = FlaxControlNetModel.from_pretrained(
101
  args.controlnet_model_name_or_path,
@@ -128,7 +138,12 @@ def infer(prompt, negative_prompt, image):
128
  prompt_ids = pipeline.prepare_text_inputs(prompts)
129
  prompt_ids = shard(prompt_ids)
130
 
131
- annotated_image = generate_annotation(image)
 
 
 
 
 
132
  validation_image = Image.fromarray(annotated_image).convert("RGB")
133
  processed_image = pipeline.prepare_image_inputs(num_samples * [validation_image])
134
  processed_image = shard(processed_image)
@@ -150,7 +165,7 @@ def infer(prompt, negative_prompt, image):
150
  images = images.reshape((images.shape[0] * images.shape[1],) + images.shape[-3:])
151
 
152
  results = [i for i in images]
153
- return [annotated_image] + results
154
 
155
 
156
  with gr.Blocks(theme='gradio/soft') as demo:
 
85
  # STEP 5: Process the classification result. In this case, visualize it.
86
  annotated_image = draw_landmarks_on_image(image.numpy_view(), detection_result, overlap=overlap, hand_encoding=hand_encoding)
87
  return annotated_image
88
+ with gr.Row():
89
+ model_type = gr.Radio(["Standard", "Hand Encoding"], label="Model preprocessing", info="We developed two models, one with standard mediapipe landmarks, and one with different (but similar) coloring on palm landmards to distinguish left and right")
90
+
91
+ if model_type=="Standard":
92
+ args = Namespace(
93
+ pretrained_model_name_or_path="runwayml/stable-diffusion-v1-5",
94
+ revision="non-ema",
95
+ from_pt=True,
96
+ controlnet_model_name_or_path="Vincent-luo/controlnet-hands",
97
+ controlnet_revision=None,
98
+ controlnet_from_pt=False,
99
+ )
100
+ if model_type=="Hand Encoding":
101
+ args = Namespace(
102
+ pretrained_model_name_or_path="runwayml/stable-diffusion-v1-5",
103
+ revision="non-ema",
104
+ from_pt=True,
105
+ controlnet_model_name_or_path="MakiPan/controlnet-encoded-hands-130k",
106
+ controlnet_revision=None,
107
+ controlnet_from_pt=False,
108
+ )
109
 
110
  controlnet, controlnet_params = FlaxControlNetModel.from_pretrained(
111
  args.controlnet_model_name_or_path,
 
138
  prompt_ids = pipeline.prepare_text_inputs(prompts)
139
  prompt_ids = shard(prompt_ids)
140
 
141
+ if model_type=="Standard":
142
+ annotated_image = generate_annotation(image, overlap=False, hand_encoding=False)
143
+ overlap_image = generate_annotation(image, overlap=True, hand_encoding=False)
144
+ if model_type=="Hand Encoding":
145
+ annotated_image = generate_annotation(image, overlap=False, hand_encoding=True)
146
+ overlap_image = generate_annotation(image, overlap=True, hand_encoding=True)
147
  validation_image = Image.fromarray(annotated_image).convert("RGB")
148
  processed_image = pipeline.prepare_image_inputs(num_samples * [validation_image])
149
  processed_image = shard(processed_image)
 
165
  images = images.reshape((images.shape[0] * images.shape[1],) + images.shape[-3:])
166
 
167
  results = [i for i in images]
168
+ return [annotated_image, overlap_image] + results
169
 
170
 
171
  with gr.Blocks(theme='gradio/soft') as demo: