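"""Gradio demo: Stable Diffusion with hand control via a MediaPipe-landmark ControlNet.

The app loads two Flax ControlNet variants (standard MediaPipe hand landmarks and
handedness-encoded landmarks), annotates the uploaded or captured image with hand
landmarks, and runs Stable Diffusion conditioned on that annotation across all
available accelerator devices.
"""
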
import jax
import jax.numpy as jnp
from flax import jax_utils
from flax.training.common_utils import shard
from PIL import Image
from argparse import Namespace
import gradio as gr
import copy
import numpy as np
import mediapipe as mp
from mediapipe import solutions
from mediapipe.framework.formats import landmark_pb2
from mediapipe.tasks import python
from mediapipe.tasks.python import vision
import cv2
import psutil
from gpuinfo import GPUInfo
import time
import gc
import torch
from diffusers import (
    FlaxControlNetModel,
    FlaxStableDiffusionControlNetPipeline,
)
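
# Recolor the default palm-landmark drawing style (indexed here via landmark 0, the
# wrist) so that left and right hands get different colors; this lets the
# "hand encoding" conditioning images carry handedness information, as described
# in the app text below.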
right_style_lm = copy.deepcopy(solutions.drawing_styles.get_default_hand_landmarks_style())
left_style_lm = copy.deepcopy(solutions.drawing_styles.get_default_hand_landmarks_style())
right_style_lm[0].color = (251, 206, 177)
left_style_lm[0].color = (255, 255, 225)


def draw_landmarks_on_image(rgb_image, detection_result, overlap=False, hand_encoding=False):
    hand_landmarks_list = detection_result.hand_landmarks
    handedness_list = detection_result.handedness

    if overlap:
        annotated_image = np.copy(rgb_image)
    else:
        annotated_image = np.zeros_like(rgb_image)

    # Loop through the detected hands to visualize.
    for idx in range(len(hand_landmarks_list)):
        hand_landmarks = hand_landmarks_list[idx]
        handedness = handedness_list[idx]

        # Draw the hand landmarks.
        hand_landmarks_proto = landmark_pb2.NormalizedLandmarkList()
        hand_landmarks_proto.landmark.extend([
            landmark_pb2.NormalizedLandmark(x=landmark.x, y=landmark.y, z=landmark.z)
            for landmark in hand_landmarks
        ])

        if hand_encoding:
            if handedness[0].category_name == "Left":
                solutions.drawing_utils.draw_landmarks(
                    annotated_image,
                    hand_landmarks_proto,
                    solutions.hands.HAND_CONNECTIONS,
                    left_style_lm,
                    solutions.drawing_styles.get_default_hand_connections_style())
            elif handedness[0].category_name == "Right":
                solutions.drawing_utils.draw_landmarks(
                    annotated_image,
                    hand_landmarks_proto,
                    solutions.hands.HAND_CONNECTIONS,
                    right_style_lm,
                    solutions.drawing_styles.get_default_hand_connections_style())
        else:
            solutions.drawing_utils.draw_landmarks(
                annotated_image,
                hand_landmarks_proto,
                solutions.hands.HAND_CONNECTIONS,
                solutions.drawing_styles.get_default_hand_landmarks_style(),
                solutions.drawing_styles.get_default_hand_connections_style())
    return annotated_image
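

# NOTE: generate_annotation expects the MediaPipe hand landmarker model file
# ('hand_landmarker.task') to be present in the working directory of the Space.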
def generate_annotation(img, overlap=False, hand_encoding=False):
    """img (input): numpy array
    annotated_image (output): numpy array
    """
    # Create a HandLandmarker object.
    base_options = python.BaseOptions(model_asset_path='hand_landmarker.task')
    options = vision.HandLandmarkerOptions(base_options=base_options,
                                           num_hands=2)
    detector = vision.HandLandmarker.create_from_options(options)

    # Load the input image.
    image = mp.Image(image_format=mp.ImageFormat.SRGB, data=img)

    # Detect hand landmarks from the input image.
    detection_result = detector.detect(image)

    # Visualize the detection result as a conditioning image.
    annotated_image = draw_landmarks_on_image(image.numpy_view(), detection_result,
                                              overlap=overlap, hand_encoding=hand_encoding)
    return annotated_image
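

# Example usage (hypothetical local file "hands.jpg"):
#     img = np.asarray(Image.open("hands.jpg").convert("RGB"))
#     conditioning = generate_annotation(img, hand_encoding=True)

# Configuration for the two ControlNet variants: "Standard" is conditioned on plain
# MediaPipe landmark drawings, "Hand Encoding" on the handedness-colored drawings.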
std_args = Namespace(
    pretrained_model_name_or_path="runwayml/stable-diffusion-v1-5",
    revision="non-ema",
    from_pt=True,
    controlnet_model_name_or_path="Vincent-luo/controlnet-hands",
    controlnet_revision=None,
    controlnet_from_pt=False,
)

enc_args = Namespace(
    pretrained_model_name_or_path="runwayml/stable-diffusion-v1-5",
    revision="non-ema",
    from_pt=True,
    controlnet_model_name_or_path="MakiPan/controlnet-encoded-hands-130k",
    controlnet_revision=None,
    controlnet_from_pt=False,
)
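
# Load both ControlNets and their Stable Diffusion pipelines in Flax.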
std_controlnet, std_controlnet_params = FlaxControlNetModel.from_pretrained(
    std_args.controlnet_model_name_or_path,
    revision=std_args.controlnet_revision,
    from_pt=std_args.controlnet_from_pt,
    dtype=jnp.float32,  # jnp.bfloat16
)

enc_controlnet, enc_controlnet_params = FlaxControlNetModel.from_pretrained(
    enc_args.controlnet_model_name_or_path,
    revision=enc_args.controlnet_revision,
    from_pt=enc_args.controlnet_from_pt,
    dtype=jnp.float32,  # jnp.bfloat16
)

std_pipeline, std_pipeline_params = FlaxStableDiffusionControlNetPipeline.from_pretrained(
    std_args.pretrained_model_name_or_path,
    controlnet=std_controlnet,
    safety_checker=None,
    dtype=jnp.float32,  # jnp.bfloat16
    revision=std_args.revision,
    from_pt=std_args.from_pt,
)

enc_pipeline, enc_pipeline_params = FlaxStableDiffusionControlNetPipeline.from_pretrained(
    enc_args.pretrained_model_name_or_path,
    controlnet=enc_controlnet,
    safety_checker=None,
    dtype=jnp.float32,  # jnp.bfloat16
    revision=enc_args.revision,
    from_pt=enc_args.from_pt,
)
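
# Attach each ControlNet's params to its pipeline params and replicate the result
# across all local devices so the pipelines can run with jit=True (pmap) below.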
std_pipeline_params["controlnet"] = std_controlnet_params
std_pipeline_params = jax_utils.replicate(std_pipeline_params)
enc_pipeline_params["controlnet"] = enc_controlnet_params
enc_pipeline_params = jax_utils.replicate(enc_pipeline_params)
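
# One sample per device: split the RNG key so every device gets its own seed.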
rng = jax.random.PRNGKey(0)
num_samples = jax.device_count()
prng_seed = jax.random.split(rng, jax.device_count())
memory = psutil.virtual_memory()
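

# infer() runs one full generation pass: build prompt/negative-prompt ids, create the
# MediaPipe conditioning image, run the selected pipeline on all devices, and collect
# basic runtime/system statistics for display.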
def infer(prompt, negative_prompt, image, model_type="Standard"):
    time_start = time.time()

    # Select the pipeline, its replicated params, and the matching preprocessing.
    if model_type == "Standard":
        pipeline, pipeline_params, hand_encoding = std_pipeline, std_pipeline_params, False
    elif model_type == "Hand Encoding":
        pipeline, pipeline_params, hand_encoding = enc_pipeline, enc_pipeline_params, True
    else:
        raise ValueError(f"Unknown model type: {model_type}")

    prompts = num_samples * [prompt]
    prompt_ids = pipeline.prepare_text_inputs(prompts)
    prompt_ids = shard(prompt_ids)

    annotated_image = generate_annotation(image, overlap=False, hand_encoding=hand_encoding)
    overlap_image = generate_annotation(image, overlap=True, hand_encoding=hand_encoding)
    validation_image = Image.fromarray(annotated_image).convert("RGB")

    processed_image = pipeline.prepare_image_inputs(num_samples * [validation_image])
    processed_image = shard(processed_image)

    negative_prompt_ids = pipeline.prepare_text_inputs([negative_prompt] * num_samples)
    negative_prompt_ids = shard(negative_prompt_ids)

    images = pipeline(
        prompt_ids=prompt_ids,
        image=processed_image,
        params=pipeline_params,
        prng_seed=prng_seed,
        num_inference_steps=50,
        neg_prompt_ids=negative_prompt_ids,
        jit=True,
    ).images

    # Merge the device dimension back into the batch dimension.
    images = images.reshape((images.shape[0] * images.shape[1],) + images.shape[-3:])
    results = [i for i in images]

    # Collect runtime and system info.
    time_end = time.time()
    time_diff = time_end - time_start
    gc.collect()
    torch.cuda.empty_cache()

    memory = psutil.virtual_memory()
    gpu_utilization, gpu_memory = GPUInfo.gpu_usage()
    gpu_utilization = gpu_utilization[0] if len(gpu_utilization) > 0 else 0
    gpu_memory = gpu_memory[0] if len(gpu_memory) > 0 else 0
    system_info = f"""
*Memory: {memory.total / (1024 * 1024 * 1024):.2f}GB, used: {memory.percent}%, available: {memory.available / (1024 * 1024 * 1024):.2f}GB.*
*Processing time: {time_diff:.5} seconds.*
*GPU Utilization: {gpu_utilization}%, GPU Memory: {gpu_memory}MiB.*
"""
    return [overlap_image, annotated_image] + results, system_info
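

# Build the Gradio interface: model/preprocessing selection, prompt inputs, image
# upload or webcam capture, an output gallery, and cached examples.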
with gr.Blocks(theme='gradio/soft') as demo:
    gr.Markdown("## Stable Diffusion with Hand Control")
    gr.Markdown("This model is a ControlNet model using MediaPipe hand landmarks for control.")

    with gr.Box():
        gr.Markdown("""<h2><b>Summary</b></h2>""")
        with gr.Accordion("Detailed information", open=False):
            gr.Markdown("""
Since Stable Diffusion and other diffusion models are notoriously poor at generating realistic hands, we decided to train a ControlNet model on MediaPipe hand landmarks in order to generate more realistic hands and avoid common issues such as unrealistic positions and irregular digits.
<br>
We opted to use the [HAnd Gesture Recognition Image Dataset](https://github.com/hukenovs/hagrid) (HaGRID) and [MediaPipe's Hand Landmarker](https://developers.google.com/mediapipe/solutions/vision/hand_landmarker) to train a ControlNet that could potentially be used independently or as an in-painting tool.

To preprocess the data, we considered three options:
<ul>
<li>The first was to use MediaPipe's built-in draw-landmarks function. This was an obvious first choice; however, we noticed at low training steps that the model couldn't easily distinguish handedness and would often generate the wrong hand for the conditioning image.</li>
<center>
<table><tr>
<td>
<p align="center" style="padding: 10px">
<img alt="Forwarding" src="https://datasets-server.huggingface.co/assets/MakiPan/hagrid250k-blip2/--/MakiPan--hagrid250k-blip2/train/29/image/image.jpg" width="200">
<br>
<em style="color: grey">Original Image</em>
</p>
</td>
<td>
<p align="center">
<img alt="Routing" src="https://datasets-server.huggingface.co/assets/MakiPan/hagrid250k-blip2/--/MakiPan--hagrid250k-blip2/train/29/conditioning_image/image.jpg" width="200">
<br>
<em style="color: grey">Conditioning Image</em>
</p>
</td>
</tr></table>
</center>
<li>To counter this issue we changed the palm landmark colors, keeping them similar enough that the model learns they carry the same kind of information, but distinct enough that it can tell left hands from right hands.</li>
<center>
<table><tr>
<td>
<p align="center" style="padding: 10px">
<img alt="Forwarding" src="https://datasets-server.huggingface.co/assets/MakiPan/hagrid-hand-enc-250k/--/MakiPan--hagrid-hand-enc-250k/train/96/image/image.jpg" width="200">
<br>
<em style="color: grey">Original Image</em>
</p>
</td>
<td>
<p align="center">
<img alt="Routing" src="https://datasets-server.huggingface.co/assets/MakiPan/hagrid-hand-enc-250k/--/MakiPan--hagrid-hand-enc-250k/train/96/conditioning_image/image.jpg" width="200">
<br>
<em style="color: grey">Conditioning Image</em>
</p>
</td>
</tr></table>
</center>
<li>The last option was to use <a href="https://ai.googleblog.com/2020/12/mediapipe-holistic-simultaneous-face.html">MediaPipe Holistic</a> to provide pose, face, and hand landmarks to the ControlNet. This method was promising in theory; in practice, however, the HaGRID dataset was not suitable, as the Holistic model performs poorly on partial-body and oddly cropped images.</li>
</ul>
We anecdotally found that, when trained for fewer steps, the hand-encoded model performed better than the standard MediaPipe model because handedness is implied by the colors. We theorize that with a larger dataset containing more full-body hand and pose images, Holistic landmarks will eventually provide the best results; for the moment, however, the hand-encoded model performs best.
""")

    # Information links
    with gr.Box():
        gr.Markdown("""<h2><b>Links</b></h2>""")
        with gr.Accordion("Models", open=False):
            gr.Markdown("""
<h4><a href="https://huggingface.co./Vincent-luo/controlnet-hands">Standard Model</a></h4>
<h4><a href="https://huggingface.co./MakiPan/controlnet-encoded-hands-130k/">Model using Hand Encoding</a></h4>
""")
        with gr.Accordion("Datasets", open=False):
            gr.Markdown("""
<h4><a href="https://huggingface.co./datasets/MakiPan/hagrid250k-blip2">Dataset for Standard Model</a></h4>
<h4><a href="https://huggingface.co./datasets/MakiPan/hagrid-hand-enc-250k">Dataset for Hand Encoding Model</a></h4>
""")
        with gr.Accordion("Preprocessing Scripts", open=False):
            gr.Markdown("""
<h4><a href="https://github.com/Maki-DS/Jax-Controlnet-hand-training/blob/main/normal-preprocessing.py">Standard Data Preprocessing Script</a></h4>
<h4><a href="https://github.com/Maki-DS/Jax-Controlnet-hand-training/blob/main/Hand-encoded-preprocessing.py">Hand Encoding Data Preprocessing Script</a></h4>
""")

    # How to use the model
    with gr.Box():
        gr.Markdown("""<h2><b>How to use</b></h2>""")
        with gr.Accordion("Generate images with ControlNet Hands", open=True):
            gr.Markdown("""
- Step 1: Select the preprocessing method (Standard or Hand Encoding).
- Step 2: Describe the image you want to create, including the hand details of the uploaded or captured image.
- Step 3: Provide a negative prompt that helps the model avoid generating redundant details.
- Step 4: Upload, or capture with the webcam, a clear image in which the hands are prominently visible in the foreground.
- Step 5: Submit and enjoy.
""")

    # Model input parameters
    model_type = gr.Radio(
        ["Standard", "Hand Encoding"],
        value="Standard",
        label="Model preprocessing",
        info="We developed two models: one trained on standard MediaPipe landmarks, and one with different (but similar) colors on the palm landmarks to distinguish left from right hands.",
    )

    with gr.Row():
        with gr.Column():
            prompt_input = gr.Textbox(label="Prompt")
            negative_prompt = gr.Textbox(label="Negative Prompt")
            with gr.Box():
                with gr.Tab("Upload Image"):
                    upload_image = gr.Image(label="Upload Image", source="upload")
                with gr.Tab("Webcam"):
                    webcam_image = gr.Image(label="Webcam", source="webcam")
            submit_btn = gr.Button(value="Submit")
            system_info = gr.Markdown(f"*Memory: {memory.total / (1024 * 1024 * 1024):.2f}GB, used: {memory.percent}%, available: {memory.available / (1024 * 1024 * 1024):.2f}GB*")
        with gr.Column():
            output_image = gr.Gallery(label='Output Image', show_label=False, elem_id="gallery").style(grid=2, height='auto')

    gr.Examples(
        examples=[
            [
                "a woman is making an ok sign in front of a painting",
                "longbody, lowres, bad anatomy, bad hands, missing fingers, extra digit, fewer digits, cropped, worst quality, low quality",
                "example.png",
            ],
            [
                "a man with his hands up in the air making a rock sign",
                "longbody, lowres, bad anatomy, bad hands, missing fingers, extra digit, fewer digits, cropped, worst quality, low quality",
                "example1.png",
            ],
            [
                "a man is making a thumbs up gesture",
                "longbody, lowres, bad anatomy, bad hands, missing fingers, extra digit, fewer digits, cropped, worst quality, low quality",
                "example2.png",
            ],
            [
                "a woman is holding up her hand in front of a window",
                "longbody, lowres, bad anatomy, bad hands, missing fingers, extra digit, fewer digits, cropped, worst quality, low quality",
                "example3.png",
            ],
            [
                "a man with his finger on his lips",
                "longbody, lowres, bad anatomy, bad hands, missing fingers, extra digit, fewer digits, cropped, worst quality, low quality",
                "example4.png",
            ],
        ],
        inputs=[prompt_input, negative_prompt, upload_image, model_type],
        outputs=[output_image, system_info],
        fn=infer,
        cache_examples=True,
    )

    # Component values are only available at event time, so pass both image inputs
    # and let a small routing helper pick whichever one actually holds an image
    # (preferring the uploaded image over the webcam capture).
    def route_inputs(prompt, negative, upload_img, webcam_img, model_choice):
        image = upload_img if upload_img is not None else webcam_img
        return infer(prompt, negative, image, model_choice)

    inputs = [prompt_input, negative_prompt, upload_image, webcam_image, model_type]
    submit_btn.click(fn=route_inputs, inputs=inputs, outputs=[output_image, system_info])

demo.launch()