Spaces:
Running
Running
Mathis Petrovich
committed on
Commit
·
6a413a4
1
Parent(s):
1d3e95c
Changing components (Radio), make examples works
Browse files
app.py
CHANGED
@@ -10,7 +10,41 @@ from load import load_model, load_json
|
|
10 |
from load import load_unit_motion_embs_splits, load_keyids_splits
|
11 |
|
12 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
13 |
EXAMPLES = [
|
|
|
14 |
"A person is walking in a circle",
|
15 |
"A person is jumping rope",
|
16 |
"Someone is doing a backflip",
|
@@ -27,25 +61,39 @@ EXAMPLES = [
|
|
27 |
"A person is taking the stairs",
|
28 |
"Someone is doing jumping jacks",
|
29 |
"The person walked forward and is picking up his toolbox",
|
30 |
-
"The person angrily punching the air
|
31 |
]
|
32 |
|
33 |
# Show closest text in the training
|
34 |
|
35 |
|
36 |
# css to make videos look nice
|
|
|
37 |
CSS = """
|
38 |
video {
|
39 |
position: relative;
|
40 |
margin: 0;
|
41 |
box-shadow: var(--block-shadow);
|
42 |
border-width: var(--block-border-width);
|
43 |
-
border-color:
|
44 |
border-radius: var(--block-radius);
|
45 |
background: var(--block-background-fill);
|
46 |
width: 100%;
|
47 |
line-height: var(--line-sm);
|
48 |
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
49 |
"""
|
50 |
|
51 |
|
@@ -82,7 +130,8 @@ def humanml3d_keyid_to_babel_rendered_url(h3d_index, amass_to_babel, keyid):
|
|
82 |
"end": end,
|
83 |
"text": text,
|
84 |
"keyid": keyid,
|
85 |
-
"babel_id": babel_id
|
|
|
86 |
}
|
87 |
|
88 |
return data
|
@@ -112,21 +161,33 @@ def retrieve(model, keyid_to_url, all_unit_motion_embs, all_keyids, text, splits
|
|
112 |
|
113 |
|
114 |
# HTML component
|
115 |
-
def get_video_html(
|
116 |
-
|
117 |
-
|
118 |
-
|
119 |
-
|
120 |
-
|
121 |
-
|
122 |
-
|
123 |
-
|
124 |
-
|
125 |
-
|
|
|
|
|
|
|
126 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
127 |
video_html = f'''
|
128 |
-
<video preload="auto" muted playsinline onpause="this.load()"
|
129 |
-
autoplay loop disablepictureinpicture id="{video_id}"
|
130 |
<source src="{url}{trim}" type="video/mp4">
|
131 |
Your browser does not support the video tag.
|
132 |
</video>
|
@@ -134,132 +195,134 @@ autoplay loop disablepictureinpicture id="{video_id}" width="{width}" height="{h
|
|
134 |
return video_html
|
135 |
|
136 |
|
137 |
-
def
|
138 |
# cannot produce more than n_component
|
139 |
nvids = min(nvids, n_component)
|
140 |
-
|
141 |
-
|
142 |
-
|
143 |
-
|
144 |
-
|
145 |
-
|
146 |
-
|
147 |
-
|
148 |
-
end=url["end"], score=url["score"]
|
149 |
-
)
|
150 |
-
for idx, url in enumerate(datas)
|
151 |
-
]
|
152 |
# get n_component exactly if asked less
|
153 |
# pad with dummy blocks
|
154 |
htmls = htmls + [None for _ in range(max(0, n_component-nvids))]
|
155 |
return htmls
|
156 |
|
157 |
|
158 |
-
def main():
|
159 |
-
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
160 |
|
161 |
-
|
162 |
-
|
163 |
-
|
164 |
-
all_unit_motion_embs = load_unit_motion_embs_splits(splits, device)
|
165 |
-
all_keyids = load_keyids_splits(splits)
|
166 |
|
167 |
-
h3d_index = load_json("amass-annotations/humanml3d.json")
|
168 |
-
amass_to_babel = load_json("amass-annotations/amass_to_babel.json")
|
169 |
|
170 |
-
|
171 |
-
retrieve_function = partial(retrieve, model, keyid_to_url, all_unit_motion_embs, all_keyids)
|
172 |
|
173 |
-
|
174 |
-
|
175 |
-
|
|
|
|
|
176 |
|
177 |
-
|
|
|
178 |
|
179 |
-
|
180 |
-
|
181 |
-
gr.Markdown(title)
|
182 |
|
183 |
-
|
184 |
-
|
185 |
-
|
186 |
-
<a href="https://ps.is.mpg.de/~black" target="_blank"><nobr>Michael J. Black</nobr></a>  
|
187 |
-
<a href="https://imagine.enpc.fr/~varolg" target="_blank"><nobr>Gül Varol</nobr></a>
|
188 |
-
</h2>
|
189 |
-
"""
|
190 |
-
gr.Markdown(authors)
|
191 |
|
192 |
-
|
193 |
-
<h2 style='text-align: center'>
|
194 |
-
<nobr>arXiv 2023</nobr>
|
195 |
-
</h2>
|
196 |
-
"""
|
197 |
-
gr.Markdown(conf)
|
198 |
|
199 |
-
|
|
|
|
|
200 |
|
201 |
-
|
202 |
-
|
203 |
-
with gr.Column(scale=2):
|
204 |
-
text = gr.Textbox(placeholder="Type in natural language, the motion to retrieve",
|
205 |
-
show_label=True, label="Text prompt", value=default_text)
|
206 |
-
with gr.Column(scale=1):
|
207 |
-
btn = gr.Button("Retrieve", variant='primary')
|
208 |
-
clear = gr.Button("Clear", variant='secondary')
|
209 |
-
|
210 |
-
with gr.Row():
|
211 |
-
with gr.Column(scale=1):
|
212 |
-
splits = gr.Dropdown(["Train", "Val", "Test"],
|
213 |
-
value=["Test"], multiselect=True, label="Splits",
|
214 |
-
info="HumanML3D data used for the motion database")
|
215 |
-
with gr.Column(scale=1):
|
216 |
-
nvideo_slider = gr.Slider(minimum=4, maximum=16, step=4, value=8, label="Number of videos")
|
217 |
with gr.Column(scale=2):
|
218 |
-
|
|
|
|
|
|
|
|
|
219 |
|
220 |
-
i = -1
|
221 |
-
# should indent
|
222 |
-
for _ in range(4):
|
223 |
with gr.Row():
|
224 |
-
|
225 |
-
|
226 |
-
|
227 |
-
|
228 |
-
|
229 |
-
|
230 |
-
|
231 |
-
if not splits:
|
232 |
-
raise gr.Error("At least one split should be selected!")
|
233 |
-
return splits
|
234 |
-
|
235 |
-
btn.click(fn=retrive_and_show, inputs=[text, splits, nvideo_slider], outputs=videos).then(
|
236 |
-
fn=check_error, inputs=splits
|
237 |
-
)
|
238 |
|
239 |
-
|
240 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
241 |
)
|
|
|
|
|
|
|
|
|
242 |
|
243 |
-
|
244 |
-
|
245 |
-
|
246 |
-
|
247 |
-
|
248 |
-
def clear_videos():
|
249 |
-
return [None for x in range(16)] + [default_text]
|
250 |
-
|
251 |
-
clear.click(fn=clear_videos, outputs=videos + [text])
|
252 |
-
demo.launch()
|
253 |
-
|
254 |
|
255 |
-
|
256 |
-
|
257 |
-
|
|
|
258 |
|
|
|
|
|
259 |
|
260 |
-
|
261 |
-
prepare()
|
262 |
-
main()
|
263 |
|
264 |
-
|
265 |
-
# A person is walking slowly
|
|
|
10 |
from load import load_unit_motion_embs_splits, load_keyids_splits
|
11 |
|
12 |
|
13 |
+
WEBSITE = """
|
14 |
+
|
15 |
+
<link href="https://cdn.jsdelivr.net/npm/bootstrap@5.2.1/dist/css/bootstrap.min.css" rel="stylesheet" integrity="sha384-iYQeCzEYFbKjA/T2uDLTpkwGzCiq6soy8tYaI1GyVh/UjpbCx/TYkiZhlZB6+fzT" crossorigin="anonymous">
|
16 |
+
<link href="https://mathis.petrovich.fr/tmr/css/style.css" rel="stylesheet">
|
17 |
+
<link href="https://mathis.petrovich.fr/tmr/css/media.css" rel="stylesheet">
|
18 |
+
|
19 |
+
|
20 |
+
<h1 style='text-align: center'>TMR: Text-to-Motion Retrieval Using Contrastive 3D Human Motion Synthesis </h1>
|
21 |
+
|
22 |
+
<h2 style='text-align: center'>
|
23 |
+
<a href="https://mathis.petrovich.fr" target="_blank"><nobr>Mathis Petrovich</nobr></a>  
|
24 |
+
<a href="https://ps.is.mpg.de/~black" target="_blank"><nobr>Michael J. Black</nobr></a>  
|
25 |
+
<a href="https://imagine.enpc.fr/~varolg" target="_blank"><nobr>Gül Varol</nobr></a>
|
26 |
+
</h2>
|
27 |
+
|
28 |
+
<h2 style='text-align: center'>
|
29 |
+
<nobr>arXiv 2023</nobr>
|
30 |
+
</h2>
|
31 |
+
|
32 |
+
<h3 style="text-align:center;">
|
33 |
+
<a target="_blank" href="https://arxiv.org/abs/XXXX.XXXXX"> <button type="button" class="btn btn-primary btn-lg"> Paper </button></a>
|
34 |
+
<a target="_blank" href="https://github.com/Mathux/TMR"> <button type="button" class="btn btn-primary btn-lg"> Code </button></a>
|
35 |
+
<a target="_blank" href="https://mathis.petrovich.fr/tmr"> <button type="button" class="btn btn-primary btn-lg"> Webpage </button></a>
|
36 |
+
<a target="_blank" href="https://mathis.petrovich.fr/tmr/tmr.bib"> <button type="button" class="btn btn-primary btn-lg"> BibTex </button></a>
|
37 |
+
</h3>
|
38 |
+
|
39 |
+
<h3> Description </h3>
|
40 |
+
<p>
|
41 |
+
This space illustrates <a href='https://mathis.petrovich.fr/tmr/' target='_blank'><b>TMR</b></a>, a method for text-to-motion retrieval. Given a gallery of 3D human motions (which can be unseen during training) and a text query, the goal is to search for motions which are close to the text query.
|
42 |
+
</p>
|
43 |
+
|
44 |
+
"""
|
45 |
+
|
46 |
EXAMPLES = [
|
47 |
+
"A person is walking slowly",
|
48 |
"A person is walking in a circle",
|
49 |
"A person is jumping rope",
|
50 |
"Someone is doing a backflip",
|
|
|
61 |
"A person is taking the stairs",
|
62 |
"Someone is doing jumping jacks",
|
63 |
"The person walked forward and is picking up his toolbox",
|
64 |
+
"The person angrily punching the air"
|
65 |
]
|
66 |
|
67 |
# Show closest text in the training
|
68 |
|
69 |
|
70 |
# css to make videos look nice
|
71 |
+
# var(--block-border-color);
|
72 |
CSS = """
|
73 |
video {
|
74 |
position: relative;
|
75 |
margin: 0;
|
76 |
box-shadow: var(--block-shadow);
|
77 |
border-width: var(--block-border-width);
|
78 |
+
border-color: #000000;
|
79 |
border-radius: var(--block-radius);
|
80 |
background: var(--block-background-fill);
|
81 |
width: 100%;
|
82 |
line-height: var(--line-sm);
|
83 |
}
|
84 |
+
|
85 |
+
.contour_video {
|
86 |
+
display: flex;
|
87 |
+
flex-direction: column;
|
88 |
+
justify-content: center;
|
89 |
+
align-items: center;
|
90 |
+
z-index: var(--layer-5);
|
91 |
+
border-radius: var(--block-radius);
|
92 |
+
background: var(--background-fill-primary);
|
93 |
+
padding: 0 var(--size-6);
|
94 |
+
max-height: var(--size-screen-h);
|
95 |
+
overflow: hidden;
|
96 |
+
}
|
97 |
"""
|
98 |
|
99 |
|
|
|
130 |
"end": end,
|
131 |
"text": text,
|
132 |
"keyid": keyid,
|
133 |
+
"babel_id": babel_id,
|
134 |
+
"path": path
|
135 |
}
|
136 |
|
137 |
return data
|
|
|
161 |
|
162 |
|
163 |
# HTML component
|
164 |
+
def get_video_html(data, video_id, width=700, height=700):
|
165 |
+
url = data["url"]
|
166 |
+
start = data["start"]
|
167 |
+
end = data["end"]
|
168 |
+
score = data["score"]
|
169 |
+
text = data["text"]
|
170 |
+
keyid = data["keyid"]
|
171 |
+
babel_id = data["babel_id"]
|
172 |
+
path = data["path"]
|
173 |
+
|
174 |
+
trim = f"#t={start},{end}"
|
175 |
+
title = f'''Score = {score}
|
176 |
+
|
177 |
+
Corresponding text: {text}
|
178 |
|
179 |
+
HumanML3D keyid: {keyid}
|
180 |
+
|
181 |
+
BABEL keyid: {babel_id}
|
182 |
+
|
183 |
+
AMASS path: {path}'''
|
184 |
+
|
185 |
+
# class="wrap default svelte-gjihhp hide"
|
186 |
+
# <div class="contour_video" style="position: absolute; padding: 10px;">
|
187 |
+
# width="{width}" height="{height}"
|
188 |
video_html = f'''
|
189 |
+
<video width="{width}" height="{height}" preload="auto" muted playsinline onpause="this.load()"
|
190 |
+
autoplay loop disablepictureinpicture id="{video_id}" title="{title}">
|
191 |
<source src="{url}{trim}" type="video/mp4">
|
192 |
Your browser does not support the video tag.
|
193 |
</video>
|
|
|
195 |
return video_html
|
196 |
|
197 |
|
198 |
+
def retrieve_component(retrieve_function, text, splits_choice, nvids, n_component=32):
|
199 |
# cannot produce more than n_component
|
200 |
nvids = min(nvids, n_component)
|
201 |
+
|
202 |
+
if "Unseen" in splits_choice:
|
203 |
+
splits = ["test"]
|
204 |
+
else:
|
205 |
+
splits = ["train", "val", "test"]
|
206 |
+
|
207 |
+
datas = retrieve_function(text, splits=splits, nmax=nvids)
|
208 |
+
htmls = [get_video_html(data, idx) for idx, data in enumerate(datas)]
|
|
|
|
|
|
|
|
|
209 |
# get n_component exactly if asked less
|
210 |
# pad with dummy blocks
|
211 |
htmls = htmls + [None for _ in range(max(0, n_component-nvids))]
|
212 |
return htmls
|
213 |
|
214 |
|
|
|
|
|
215 |
|
216 |
+
if not os.path.exists("data"):
|
217 |
+
gdown.download_folder("https://drive.google.com/drive/folders/1MgPFgHZ28AMd01M1tJ7YW_1-ut3-4j08",
|
218 |
+
use_cookies=False)
|
|
|
|
|
219 |
|
|
|
|
|
220 |
|
221 |
+
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
|
|
222 |
|
223 |
+
# LOADING
|
224 |
+
model = load_model(device)
|
225 |
+
splits = ["train", "val", "test"]
|
226 |
+
all_unit_motion_embs = load_unit_motion_embs_splits(splits, device)
|
227 |
+
all_keyids = load_keyids_splits(splits)
|
228 |
|
229 |
+
h3d_index = load_json("amass-annotations/humanml3d.json")
|
230 |
+
amass_to_babel = load_json("amass-annotations/amass_to_babel.json")
|
231 |
|
232 |
+
keyid_to_url = partial(humanml3d_keyid_to_babel_rendered_url, h3d_index, amass_to_babel)
|
233 |
+
retrieve_function = partial(retrieve, model, keyid_to_url, all_unit_motion_embs, all_keyids)
|
|
|
234 |
|
235 |
+
# DEMO
|
236 |
+
theme = gr.themes.Default(primary_hue="blue", secondary_hue="gray")
|
237 |
+
retrieve_and_show = partial(retrieve_component, retrieve_function)
|
|
|
|
|
|
|
|
|
|
|
238 |
|
239 |
+
default_text = "A person is "
|
|
|
|
|
|
|
|
|
|
|
240 |
|
241 |
+
with gr.Blocks(css=CSS, theme=theme) as demo:
|
242 |
+
gr.Markdown(WEBSITE)
|
243 |
+
videos = []
|
244 |
|
245 |
+
with gr.Row():
|
246 |
+
with gr.Column(scale=3):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
247 |
with gr.Column(scale=2):
|
248 |
+
text = gr.Textbox(placeholder="Type the motion you want to search with a sentence",
|
249 |
+
show_label=True, label="Text prompt", value=default_text)
|
250 |
+
with gr.Column(scale=1):
|
251 |
+
btn = gr.Button("Retrieve", variant='primary')
|
252 |
+
clear = gr.Button("Clear", variant='secondary')
|
253 |
|
|
|
|
|
|
|
254 |
with gr.Row():
|
255 |
+
with gr.Column(scale=1):
|
256 |
+
# splits = gr.Dropdown(["Train", "Val", "Test"],
|
257 |
+
# value=["Test"], multiselect=True, label="Splits",
|
258 |
+
# info="HumanML3D data used for the motion database")
|
259 |
+
splits_choice = gr.Radio(["Unseen motions", "All motions"], label="Gallery of motion",
|
260 |
+
value="Unseen motions",
|
261 |
+
info="The motion gallery is coming from HumanML3D")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
262 |
|
263 |
+
with gr.Column(scale=1):
|
264 |
+
# nvideo_slider = gr.Slider(minimum=4, maximum=32, step=4, value=8, label="Number of videos")
|
265 |
+
nvideo_slider = gr.Radio([4, 8, 12, 16, 24, 28], label="Videos",
|
266 |
+
value=8,
|
267 |
+
info="Number of videos to display")
|
268 |
+
|
269 |
+
with gr.Column(scale=2):
|
270 |
+
def retrieve_example(text, splits_choice, nvideo_slider):
|
271 |
+
return retrieve_and_show(text, splits_choice, nvideo_slider)
|
272 |
+
|
273 |
+
examples = gr.Examples(examples=[[x, None, None] for x in EXAMPLES],
|
274 |
+
inputs=[text, splits_choice, nvideo_slider],
|
275 |
+
examples_per_page=20,
|
276 |
+
run_on_click=False, cache_examples=False,
|
277 |
+
fn=retrieve_example, outputs=[])
|
278 |
+
|
279 |
+
i = -1
|
280 |
+
# should indent
|
281 |
+
for _ in range(8):
|
282 |
+
with gr.Row():
|
283 |
+
for _ in range(4):
|
284 |
+
i += 1
|
285 |
+
video = gr.HTML()
|
286 |
+
videos.append(video)
|
287 |
+
|
288 |
+
# connect the examples to the output
|
289 |
+
# a bit hacky
|
290 |
+
examples.outputs = videos
|
291 |
+
|
292 |
+
def load_example(example_id):
|
293 |
+
processed_example = examples.non_none_processed_examples[example_id]
|
294 |
+
return gr.utils.resolve_singleton(processed_example)
|
295 |
+
|
296 |
+
examples.dataset.click(
|
297 |
+
load_example,
|
298 |
+
inputs=[examples.dataset],
|
299 |
+
outputs=examples.inputs_with_examples, # type: ignore
|
300 |
+
show_progress=False,
|
301 |
+
postprocess=False,
|
302 |
+
queue=False,
|
303 |
+
).then(
|
304 |
+
fn=retrieve_example,
|
305 |
+
inputs=examples.inputs,
|
306 |
+
outputs=videos
|
307 |
)
|
308 |
+
# def check_error(splits):
|
309 |
+
# if not splits:
|
310 |
+
# raise gr.Error("At least one split should be selected!")
|
311 |
+
# return splits
|
312 |
|
313 |
+
btn.click(fn=retrieve_and_show, inputs=[text, splits_choice, nvideo_slider], outputs=videos)
|
314 |
+
#.then(
|
315 |
+
# fn=check_error, inputs=splits
|
316 |
+
# )
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
317 |
|
318 |
+
text.submit(fn=retrieve_and_show, inputs=[text, splits_choice, nvideo_slider], outputs=videos)
|
319 |
+
# .then(
|
320 |
+
# fn=check_error, inputs=splits
|
321 |
+
# )
|
322 |
|
323 |
+
def clear_videos():
|
324 |
+
return [None for x in range(32)] + [default_text]
|
325 |
|
326 |
+
clear.click(fn=clear_videos, outputs=videos + [text])
|
|
|
|
|
327 |
|
328 |
+
demo.launch()
|
|