Mathis Petrovich commited on
Commit
6a413a4
·
1 Parent(s): 1d3e95c

Changing components (Radio), make examples work

Browse files
Files changed (1) hide show
  1. app.py +180 -117
app.py CHANGED
@@ -10,7 +10,41 @@ from load import load_model, load_json
10
  from load import load_unit_motion_embs_splits, load_keyids_splits
11
 
12
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
13
  EXAMPLES = [
 
14
  "A person is walking in a circle",
15
  "A person is jumping rope",
16
  "Someone is doing a backflip",
@@ -27,25 +61,39 @@ EXAMPLES = [
27
  "A person is taking the stairs",
28
  "Someone is doing jumping jacks",
29
  "The person walked forward and is picking up his toolbox",
30
- "The person angrily punching the air."
31
  ]
32
 
33
  # Show closest text in the training
34
 
35
 
36
  # css to make videos look nice
 
37
  CSS = """
38
  video {
39
  position: relative;
40
  margin: 0;
41
  box-shadow: var(--block-shadow);
42
  border-width: var(--block-border-width);
43
- border-color: var(--block-border-color);
44
  border-radius: var(--block-radius);
45
  background: var(--block-background-fill);
46
  width: 100%;
47
  line-height: var(--line-sm);
48
  }
 
 
 
 
 
 
 
 
 
 
 
 
 
49
  """
50
 
51
 
@@ -82,7 +130,8 @@ def humanml3d_keyid_to_babel_rendered_url(h3d_index, amass_to_babel, keyid):
82
  "end": end,
83
  "text": text,
84
  "keyid": keyid,
85
- "babel_id": babel_id
 
86
  }
87
 
88
  return data
@@ -112,21 +161,33 @@ def retrieve(model, keyid_to_url, all_unit_motion_embs, all_keyids, text, splits
112
 
113
 
114
  # HTML component
115
- def get_video_html(url, video_id, start=None, end=None, score=None, width=350, height=350):
116
- trim = ""
117
- if start is not None:
118
- if end is not None:
119
- trim = f"#t={start},{end}"
120
- else:
121
- trim = f"#t={start}"
122
-
123
- score_t = ""
124
- if score is not None:
125
- score_t = f'title="Score = {score}"'
 
 
 
126
 
 
 
 
 
 
 
 
 
 
127
  video_html = f'''
128
- <video preload="auto" muted playsinline onpause="this.load()"
129
- autoplay loop disablepictureinpicture id="{video_id}" width="{width}" height="{height}" {score_t}>
130
  <source src="{url}{trim}" type="video/mp4">
131
  Your browser does not support the video tag.
132
  </video>
@@ -134,132 +195,134 @@ autoplay loop disablepictureinpicture id="{video_id}" width="{width}" height="{h
134
  return video_html
135
 
136
 
137
- def retrive_component(retrieve_function, text, splits, nvids, n_component=16):
138
  # cannot produce more than n_component
139
  nvids = min(nvids, n_component)
140
- if not splits:
141
- return [None for _ in range(n_component)]
142
-
143
- splits_l = [x.lower() for x in splits]
144
- datas = retrieve_function(text, splits=splits_l, nmax=nvids)
145
- htmls = [
146
- get_video_html(
147
- url["url"], idx, start=url["start"],
148
- end=url["end"], score=url["score"]
149
- )
150
- for idx, url in enumerate(datas)
151
- ]
152
  # get n_component exactly if asked less
153
  # pad with dummy blocks
154
  htmls = htmls + [None for _ in range(max(0, n_component-nvids))]
155
  return htmls
156
 
157
 
158
- def main():
159
- device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
160
 
161
- # LOADING
162
- model = load_model(device)
163
- splits = ["train", "val", "test"]
164
- all_unit_motion_embs = load_unit_motion_embs_splits(splits, device)
165
- all_keyids = load_keyids_splits(splits)
166
 
167
- h3d_index = load_json("amass-annotations/humanml3d.json")
168
- amass_to_babel = load_json("amass-annotations/amass_to_babel.json")
169
 
170
- keyid_to_url = partial(humanml3d_keyid_to_babel_rendered_url, h3d_index, amass_to_babel)
171
- retrieve_function = partial(retrieve, model, keyid_to_url, all_unit_motion_embs, all_keyids)
172
 
173
- # DEMO
174
- theme = gr.themes.Default(primary_hue="blue", secondary_hue="gray")
175
- retrive_and_show = partial(retrive_component, retrieve_function)
 
 
176
 
177
- default_text = "A person is "
 
178
 
179
- with gr.Blocks(css=CSS, theme=theme) as demo:
180
- title = "<h1 style='text-align: center'>TMR: Text-to-Motion Retrieval Using Contrastive 3D Human Motion Synthesis </h1>"
181
- gr.Markdown(title)
182
 
183
- authors = """
184
- <h2 style='text-align: center'>
185
- <a href="https://mathis.petrovich.fr" target="_blank"><nobr>Mathis Petrovich</nobr></a> &emsp;
186
- <a href="https://ps.is.mpg.de/~black" target="_blank"><nobr>Michael J. Black</nobr></a> &emsp;
187
- <a href="https://imagine.enpc.fr/~varolg" target="_blank"><nobr>G&uuml;l Varol</nobr></a>
188
- </h2>
189
- """
190
- gr.Markdown(authors)
191
 
192
- conf = """
193
- <h2 style='text-align: center'>
194
- <nobr>arXiv 2023</nobr>
195
- </h2>
196
- """
197
- gr.Markdown(conf)
198
 
199
- videos = []
 
 
200
 
201
- with gr.Row():
202
- with gr.Column(scale=3):
203
- with gr.Column(scale=2):
204
- text = gr.Textbox(placeholder="Type in natural language, the motion to retrieve",
205
- show_label=True, label="Text prompt", value=default_text)
206
- with gr.Column(scale=1):
207
- btn = gr.Button("Retrieve", variant='primary')
208
- clear = gr.Button("Clear", variant='secondary')
209
-
210
- with gr.Row():
211
- with gr.Column(scale=1):
212
- splits = gr.Dropdown(["Train", "Val", "Test"],
213
- value=["Test"], multiselect=True, label="Splits",
214
- info="HumanML3D data used for the motion database")
215
- with gr.Column(scale=1):
216
- nvideo_slider = gr.Slider(minimum=4, maximum=16, step=4, value=8, label="Number of videos")
217
  with gr.Column(scale=2):
218
- examples = gr.Examples(examples=EXAMPLES, inputs=text, examples_per_page=15)
 
 
 
 
219
 
220
- i = -1
221
- # should indent
222
- for _ in range(4):
223
  with gr.Row():
224
- for _ in range(4):
225
- i += 1
226
- with gr.Column():
227
- video = gr.HTML()
228
- videos.append(video)
229
-
230
- def check_error(splits):
231
- if not splits:
232
- raise gr.Error("At least one split should be selected!")
233
- return splits
234
-
235
- btn.click(fn=retrive_and_show, inputs=[text, splits, nvideo_slider], outputs=videos).then(
236
- fn=check_error, inputs=splits
237
- )
238
 
239
- text.submit(fn=retrive_and_show, inputs=[text, splits, nvideo_slider], outputs=videos).then(
240
- fn=check_error, inputs=splits
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
241
  )
 
 
 
 
242
 
243
- def keep_test(splits):
244
- if len(splits) == 0:
245
- return ["Test"]
246
- return splits
247
-
248
- def clear_videos():
249
- return [None for x in range(16)] + [default_text]
250
-
251
- clear.click(fn=clear_videos, outputs=videos + [text])
252
- demo.launch()
253
-
254
 
255
- def prepare():
256
- if not os.path.exists("data"):
257
- gdown.download_folder("https://drive.google.com/drive/folders/1MgPFgHZ28AMd01M1tJ7YW_1-ut3-4j08", use_cookies=False)
 
258
 
 
 
259
 
260
- if __name__ == "__main__":
261
- prepare()
262
- main()
263
 
264
- # new
265
- # A person is walking slowly
 
10
  from load import load_unit_motion_embs_splits, load_keyids_splits
11
 
12
 
13
+ WEBSITE = """
14
+
15
+ <link href="https://cdn.jsdelivr.net/npm/[email protected]/dist/css/bootstrap.min.css" rel="stylesheet" integrity="sha384-iYQeCzEYFbKjA/T2uDLTpkwGzCiq6soy8tYaI1GyVh/UjpbCx/TYkiZhlZB6+fzT" crossorigin="anonymous">
16
+ <link href="https://mathis.petrovich.fr/tmr/css/style.css" rel="stylesheet">
17
+ <link href="https://mathis.petrovich.fr/tmr/css/media.css" rel="stylesheet">
18
+
19
+
20
+ <h1 style='text-align: center'>TMR: Text-to-Motion Retrieval Using Contrastive 3D Human Motion Synthesis </h1>
21
+
22
+ <h2 style='text-align: center'>
23
+ <a href="https://mathis.petrovich.fr" target="_blank"><nobr>Mathis Petrovich</nobr></a> &emsp;
24
+ <a href="https://ps.is.mpg.de/~black" target="_blank"><nobr>Michael J. Black</nobr></a> &emsp;
25
+ <a href="https://imagine.enpc.fr/~varolg" target="_blank"><nobr>G&uuml;l Varol</nobr></a>
26
+ </h2>
27
+
28
+ <h2 style='text-align: center'>
29
+ <nobr>arXiv 2023</nobr>
30
+ </h2>
31
+
32
+ <h3 style="text-align:center;">
33
+ <a target="_blank" href="https://arxiv.org/abs/XXXX.XXXXX"> <button type="button" class="btn btn-primary btn-lg"> Paper </button></a>
34
+ <a target="_blank" href="https://github.com/Mathux/TMR"> <button type="button" class="btn btn-primary btn-lg"> Code </button></a>
35
+ <a target="_blank" href="https://mathis.petrovich.fr/tmr"> <button type="button" class="btn btn-primary btn-lg"> Webpage </button></a>
36
+ <a target="_blank" href="https://mathis.petrovich.fr/tmr/tmr.bib"> <button type="button" class="btn btn-primary btn-lg"> BibTex </button></a>
37
+ </h3>
38
+
39
+ <h3> Description </h3>
40
+ <p>
41
+ This space illustrates <a href='https://mathis.petrovich.fr/tmr/' target='_blank'><b>TMR</b></a>, a method for text-to-motion retrieval. Given a gallery of 3D human motions (which can be unseen during training) and a text query, the goal is to search for motions which are close to the text query.
42
+ </p>
43
+
44
+ """
45
+
46
  EXAMPLES = [
47
+ "A person is walking slowly",
48
  "A person is walking in a circle",
49
  "A person is jumping rope",
50
  "Someone is doing a backflip",
 
61
  "A person is taking the stairs",
62
  "Someone is doing jumping jacks",
63
  "The person walked forward and is picking up his toolbox",
64
+ "The person angrily punching the air"
65
  ]
66
 
67
  # Show closest text in the training
68
 
69
 
70
  # css to make videos look nice
71
+ # var(--block-border-color);
72
  CSS = """
73
  video {
74
  position: relative;
75
  margin: 0;
76
  box-shadow: var(--block-shadow);
77
  border-width: var(--block-border-width);
78
+ border-color: #000000;
79
  border-radius: var(--block-radius);
80
  background: var(--block-background-fill);
81
  width: 100%;
82
  line-height: var(--line-sm);
83
  }
84
+
85
+ .contour_video {
86
+ display: flex;
87
+ flex-direction: column;
88
+ justify-content: center;
89
+ align-items: center;
90
+ z-index: var(--layer-5);
91
+ border-radius: var(--block-radius);
92
+ background: var(--background-fill-primary);
93
+ padding: 0 var(--size-6);
94
+ max-height: var(--size-screen-h);
95
+ overflow: hidden;
96
+ }
97
  """
98
 
99
 
 
130
  "end": end,
131
  "text": text,
132
  "keyid": keyid,
133
+ "babel_id": babel_id,
134
+ "path": path
135
  }
136
 
137
  return data
 
161
 
162
 
163
  # HTML component
164
+ def get_video_html(data, video_id, width=700, height=700):
165
+ url = data["url"]
166
+ start = data["start"]
167
+ end = data["end"]
168
+ score = data["score"]
169
+ text = data["text"]
170
+ keyid = data["keyid"]
171
+ babel_id = data["babel_id"]
172
+ path = data["path"]
173
+
174
+ trim = f"#t={start},{end}"
175
+ title = f'''Score = {score}
176
+
177
+ Corresponding text: {text}
178
 
179
+ HumanML3D keyid: {keyid}
180
+
181
+ BABEL keyid: {babel_id}
182
+
183
+ AMASS path: {path}'''
184
+
185
+ # class="wrap default svelte-gjihhp hide"
186
+ # <div class="contour_video" style="position: absolute; padding: 10px;">
187
+ # width="{width}" height="{height}"
188
  video_html = f'''
189
+ <video width="{width}" height="{height}" preload="auto" muted playsinline onpause="this.load()"
190
+ autoplay loop disablepictureinpicture id="{video_id}" title="{title}">
191
  <source src="{url}{trim}" type="video/mp4">
192
  Your browser does not support the video tag.
193
  </video>
 
195
  return video_html
196
 
197
 
198
+ def retrieve_component(retrieve_function, text, splits_choice, nvids, n_component=32):
199
  # cannot produce more than n_component
200
  nvids = min(nvids, n_component)
201
+
202
+ if "Unseen" in splits_choice:
203
+ splits = ["test"]
204
+ else:
205
+ splits = ["train", "val", "test"]
206
+
207
+ datas = retrieve_function(text, splits=splits, nmax=nvids)
208
+ htmls = [get_video_html(data, idx) for idx, data in enumerate(datas)]
 
 
 
 
209
  # get n_component exactly if asked less
210
  # pad with dummy blocks
211
  htmls = htmls + [None for _ in range(max(0, n_component-nvids))]
212
  return htmls
213
 
214
 
 
 
215
 
216
+ if not os.path.exists("data"):
217
+ gdown.download_folder("https://drive.google.com/drive/folders/1MgPFgHZ28AMd01M1tJ7YW_1-ut3-4j08",
218
+ use_cookies=False)
 
 
219
 
 
 
220
 
221
+ device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
 
222
 
223
+ # LOADING
224
+ model = load_model(device)
225
+ splits = ["train", "val", "test"]
226
+ all_unit_motion_embs = load_unit_motion_embs_splits(splits, device)
227
+ all_keyids = load_keyids_splits(splits)
228
 
229
+ h3d_index = load_json("amass-annotations/humanml3d.json")
230
+ amass_to_babel = load_json("amass-annotations/amass_to_babel.json")
231
 
232
+ keyid_to_url = partial(humanml3d_keyid_to_babel_rendered_url, h3d_index, amass_to_babel)
233
+ retrieve_function = partial(retrieve, model, keyid_to_url, all_unit_motion_embs, all_keyids)
 
234
 
235
+ # DEMO
236
+ theme = gr.themes.Default(primary_hue="blue", secondary_hue="gray")
237
+ retrieve_and_show = partial(retrieve_component, retrieve_function)
 
 
 
 
 
238
 
239
+ default_text = "A person is "
 
 
 
 
 
240
 
241
+ with gr.Blocks(css=CSS, theme=theme) as demo:
242
+ gr.Markdown(WEBSITE)
243
+ videos = []
244
 
245
+ with gr.Row():
246
+ with gr.Column(scale=3):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
247
  with gr.Column(scale=2):
248
+ text = gr.Textbox(placeholder="Type the motion you want to search with a sentence",
249
+ show_label=True, label="Text prompt", value=default_text)
250
+ with gr.Column(scale=1):
251
+ btn = gr.Button("Retrieve", variant='primary')
252
+ clear = gr.Button("Clear", variant='secondary')
253
 
 
 
 
254
  with gr.Row():
255
+ with gr.Column(scale=1):
256
+ # splits = gr.Dropdown(["Train", "Val", "Test"],
257
+ # value=["Test"], multiselect=True, label="Splits",
258
+ # info="HumanML3D data used for the motion database")
259
+ splits_choice = gr.Radio(["Unseen motions", "All motions"], label="Gallery of motion",
260
+ value="Unseen motions",
261
+ info="The motion gallery is coming from HumanML3D")
 
 
 
 
 
 
 
262
 
263
+ with gr.Column(scale=1):
264
+ # nvideo_slider = gr.Slider(minimum=4, maximum=32, step=4, value=8, label="Number of videos")
265
+ nvideo_slider = gr.Radio([4, 8, 12, 16, 24, 28], label="Videos",
266
+ value=8,
267
+ info="Number of videos to display")
268
+
269
+ with gr.Column(scale=2):
270
+ def retrieve_example(text, splits_choice, nvideo_slider):
271
+ return retrieve_and_show(text, splits_choice, nvideo_slider)
272
+
273
+ examples = gr.Examples(examples=[[x, None, None] for x in EXAMPLES],
274
+ inputs=[text, splits_choice, nvideo_slider],
275
+ examples_per_page=20,
276
+ run_on_click=False, cache_examples=False,
277
+ fn=retrieve_example, outputs=[])
278
+
279
+ i = -1
280
+ # should indent
281
+ for _ in range(8):
282
+ with gr.Row():
283
+ for _ in range(4):
284
+ i += 1
285
+ video = gr.HTML()
286
+ videos.append(video)
287
+
288
+ # connect the examples to the output
289
+ # a bit hacky
290
+ examples.outputs = videos
291
+
292
+ def load_example(example_id):
293
+ processed_example = examples.non_none_processed_examples[example_id]
294
+ return gr.utils.resolve_singleton(processed_example)
295
+
296
+ examples.dataset.click(
297
+ load_example,
298
+ inputs=[examples.dataset],
299
+ outputs=examples.inputs_with_examples, # type: ignore
300
+ show_progress=False,
301
+ postprocess=False,
302
+ queue=False,
303
+ ).then(
304
+ fn=retrieve_example,
305
+ inputs=examples.inputs,
306
+ outputs=videos
307
  )
308
+ # def check_error(splits):
309
+ # if not splits:
310
+ # raise gr.Error("At least one split should be selected!")
311
+ # return splits
312
 
313
+ btn.click(fn=retrieve_and_show, inputs=[text, splits_choice, nvideo_slider], outputs=videos)
314
+ #.then(
315
+ # fn=check_error, inputs=splits
316
+ # )
 
 
 
 
 
 
 
317
 
318
+ text.submit(fn=retrieve_and_show, inputs=[text, splits_choice, nvideo_slider], outputs=videos)
319
+ # .then(
320
+ # fn=check_error, inputs=splits
321
+ # )
322
 
323
+ def clear_videos():
324
+ return [None for x in range(32)] + [default_text]
325
 
326
+ clear.click(fn=clear_videos, outputs=videos + [text])
 
 
327
 
328
+ demo.launch()