retir commited on
Commit
f864334
β€’
1 Parent(s): 7bc29a1

add unalign

Browse files
Files changed (2) hide show
  1. app.py +14 -6
  2. inference_pb2.py +6 -6
app.py CHANGED
@@ -82,11 +82,11 @@ def bytes_to_image(image: bytes) -> Image.Image:
82
  return image
83
 
84
 
85
- def edit_image(orig_image, edit_direction, edit_power, align, mask, progress=gr.Progress(track_tqdm=True)):
86
  if edit_direction in DIRECTIONS_NAME_SWAP:
87
  edit_direction = DIRECTIONS_NAME_SWAP[edit_direction]
88
  if not orig_image:
89
- return gr.update(visible=False), gr.update(visible=False), gr.update(value="Need to upload an input image ❗", visible=True)
90
 
91
  orig_image_bytes = get_bytes(orig_image)
92
  mask_bytes = get_bytes(mask)
@@ -103,11 +103,16 @@ def edit_image(orig_image, edit_direction, edit_power, align, mask, progress=gr.
103
  )
104
 
105
  if output.image == b"aligner error":
106
- return gr.update(visible=False), gr.update(visible=False), gr.update(value="Face aligner can not find face in your image 😒 Try to upload another one", visible=True)
107
 
108
  output_edited = bytes_to_image(output.image)
109
  output_inv = bytes_to_image(output.inv_image)
110
- return gr.update(value=output_edited, visible=True), gr.update(value=output_inv, visible=True), gr.update(visible=False)
 
 
 
 
 
111
 
112
 
113
  def edit_image_clip(orig_image, neutral_prompt, target_prompt, disentanglement, edit_power, align, mask, edit_method, progress=gr.Progress(track_tqdm=True)):
@@ -220,6 +225,9 @@ def get_demo():
220
  btn_mask = gr.Button("Generate mask")
221
 
222
  with gr.Column():
 
 
 
223
  with gr.Row():
224
  output_inv = gr.Image(label="Inversion result", visible=True)
225
  output_edit = gr.Image(label="Editing result", visible=True)
@@ -264,12 +272,12 @@ def get_demo():
264
  btn_predef.click(
265
  fn=edit_image,
266
  inputs=[input_image, predef_editing_direction, predef_editing_power, align, mask],
267
- outputs=[output_edit, output_inv, error_message]
268
  )
269
  btn_clip.click(
270
  fn=edit_image_clip,
271
  inputs=[input_image, neutral_prompt, target_prompt, disentanglement, styleclip_editing_power, align, mask, edit_method],
272
- outputs=[output_edit, output_inv, error_message]
273
  )
274
  btn_mask.click(
275
  fn=get_mask,
 
82
  return image
83
 
84
 
85
+ def edit_image(orig_image, edit_direction, edit_power, align, mask, progress=gr.Progress(track_tqdm=True)): # output_align, output_unalign
86
  if edit_direction in DIRECTIONS_NAME_SWAP:
87
  edit_direction = DIRECTIONS_NAME_SWAP[edit_direction]
88
  if not orig_image:
89
+ return gr.update(visible=False), gr.update(visible=False), gr.update(value="Need to upload an input image ❗", visible=True), gr.update(visible=False), gr.update(visible=False)
90
 
91
  orig_image_bytes = get_bytes(orig_image)
92
  mask_bytes = get_bytes(mask)
 
103
  )
104
 
105
  if output.image == b"aligner error":
106
+ return gr.update(visible=False), gr.update(visible=False), gr.update(value="Face aligner can not find face in your image 😒 Try to upload another one", visible=True), gr.update(visible=False), gr.update(visible=False),
107
 
108
  output_edited = bytes_to_image(output.image)
109
  output_inv = bytes_to_image(output.inv_image)
110
+ if not align:
111
+ return gr.update(value=output_edited, visible=True), gr.update(value=output_inv, visible=True), gr.update(visible=False), gr.update(visible=False), gr.update(visible=False)
112
+
113
+ output_aligned = bytes_to_image(output.aligned)
114
+ output_unaligned = bytes_to_image(output.unaligned)
115
+ return gr.update(value=output_edited, visible=True), gr.update(value=output_inv, visible=True), gr.update(visible=False), gr.update(value=output_aligned, visible=True), gr.update(value=output_unaligned, visible=True)
116
 
117
 
118
  def edit_image_clip(orig_image, neutral_prompt, target_prompt, disentanglement, edit_power, align, mask, edit_method, progress=gr.Progress(track_tqdm=True)):
 
225
  btn_mask = gr.Button("Generate mask")
226
 
227
  with gr.Column():
228
+ with gr.Row():
229
+ output_align = gr.Image(label="Alignet original image", visible=True)
230
+ output_unalign = gr.Image(label="Unalinget editing result", visible=True)
231
  with gr.Row():
232
  output_inv = gr.Image(label="Inversion result", visible=True)
233
  output_edit = gr.Image(label="Editing result", visible=True)
 
272
  btn_predef.click(
273
  fn=edit_image,
274
  inputs=[input_image, predef_editing_direction, predef_editing_power, align, mask],
275
+ outputs=[output_edit, output_inv, error_message, output_align, output_unalign]
276
  )
277
  btn_clip.click(
278
  fn=edit_image_clip,
279
  inputs=[input_image, neutral_prompt, target_prompt, disentanglement, styleclip_editing_power, align, mask, edit_method],
280
+ outputs=[output_edit, output_inv, error_message, output_align, output_unalign,]
281
  )
282
  btn_mask.click(
283
  fn=get_mask,
inference_pb2.py CHANGED
@@ -14,7 +14,7 @@ _sym_db = _symbol_database.Default()
14
 
15
 
16
 
17
- DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x0finference.proto\x12\tinference\"r\n\nSFERequest\x12\x12\n\norig_image\x18\x01 \x01(\x0c\x12\x11\n\tdirection\x18\x02 \x01(\t\x12\r\n\x05power\x18\x03 \x01(\x02\x12\x11\n\tuse_cache\x18\x04 \x01(\x08\x12\r\n\x05\x61lign\x18\x05 \x01(\x08\x12\x0c\n\x04mask\x18\x06 \x01(\x0c\"X\n\x0eSFERequestMask\x12\x12\n\norig_image\x18\x01 \x01(\x0c\x12\x10\n\x08trashold\x18\x02 \x01(\x02\x12\x11\n\tuse_cache\x18\x03 \x01(\x08\x12\r\n\x05\x61lign\x18\x04 \x01(\x08\"/\n\x0bSFEResponse\x12\r\n\x05image\x18\x01 \x01(\x0c\x12\x11\n\tinv_image\x18\x02 \x01(\x0c\"\x1f\n\x0fSFEResponseMask\x12\x0c\n\x04mask\x18\x01 \x01(\x0c\x32\x8b\x01\n\nSFEService\x12\x35\n\x04\x65\x64it\x12\x15.inference.SFERequest\x1a\x16.inference.SFEResponse\x12\x46\n\rgenerate_mask\x12\x19.inference.SFERequestMask\x1a\x1a.inference.SFEResponseMaskb\x06proto3')
18
 
19
  _globals = globals()
20
  _builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals)
@@ -26,9 +26,9 @@ if not _descriptor._USE_C_DESCRIPTORS:
26
  _globals['_SFEREQUESTMASK']._serialized_start=146
27
  _globals['_SFEREQUESTMASK']._serialized_end=234
28
  _globals['_SFERESPONSE']._serialized_start=236
29
- _globals['_SFERESPONSE']._serialized_end=283
30
- _globals['_SFERESPONSEMASK']._serialized_start=285
31
- _globals['_SFERESPONSEMASK']._serialized_end=316
32
- _globals['_SFESERVICE']._serialized_start=319
33
- _globals['_SFESERVICE']._serialized_end=458
34
  # @@protoc_insertion_point(module_scope)
 
14
 
15
 
16
 
17
+ DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x0finference.proto\x12\tinference\"r\n\nSFERequest\x12\x12\n\norig_image\x18\x01 \x01(\x0c\x12\x11\n\tdirection\x18\x02 \x01(\t\x12\r\n\x05power\x18\x03 \x01(\x02\x12\x11\n\tuse_cache\x18\x04 \x01(\x08\x12\r\n\x05\x61lign\x18\x05 \x01(\x08\x12\x0c\n\x04mask\x18\x06 \x01(\x0c\"X\n\x0eSFERequestMask\x12\x12\n\norig_image\x18\x01 \x01(\x0c\x12\x10\n\x08trashold\x18\x02 \x01(\x02\x12\x11\n\tuse_cache\x18\x03 \x01(\x08\x12\r\n\x05\x61lign\x18\x04 \x01(\x08\"S\n\x0bSFEResponse\x12\r\n\x05image\x18\x01 \x01(\x0c\x12\x11\n\tinv_image\x18\x02 \x01(\x0c\x12\x0f\n\x07\x61ligned\x18\x03 \x01(\x0c\x12\x11\n\tunaligned\x18\x04 \x01(\x0c\"\x1f\n\x0fSFEResponseMask\x12\x0c\n\x04mask\x18\x01 \x01(\x0c\x32\x8b\x01\n\nSFEService\x12\x35\n\x04\x65\x64it\x12\x15.inference.SFERequest\x1a\x16.inference.SFEResponse\x12\x46\n\rgenerate_mask\x12\x19.inference.SFERequestMask\x1a\x1a.inference.SFEResponseMaskb\x06proto3')
18
 
19
  _globals = globals()
20
  _builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals)
 
26
  _globals['_SFEREQUESTMASK']._serialized_start=146
27
  _globals['_SFEREQUESTMASK']._serialized_end=234
28
  _globals['_SFERESPONSE']._serialized_start=236
29
+ _globals['_SFERESPONSE']._serialized_end=319
30
+ _globals['_SFERESPONSEMASK']._serialized_start=321
31
+ _globals['_SFERESPONSEMASK']._serialized_end=352
32
+ _globals['_SFESERVICE']._serialized_start=355
33
+ _globals['_SFESERVICE']._serialized_end=494
34
  # @@protoc_insertion_point(module_scope)