Upload 2 files

- app.py +3 -3
- hfsearch.py +134 -46
app.py
CHANGED
@@ -17,10 +17,10 @@ with gr.Blocks(theme="NoCrypt/miku", fill_width=True, css=CSS) as demo:
     with gr.Tab("Normal Search"):
         with gr.Group():
             with gr.Row(equal_height=True):
-                repo_types = gr.CheckboxGroup(label="Repo type", choices=["model", "dataset", "space"], value=["model", "dataset", "space"])
+                repo_types = gr.CheckboxGroup(label="Repo type", choices=["model", "dataset", "space", "collection"], value=["model", "dataset", "space"])
+                filter_str = gr.Textbox(label="Filter", info="String(s) to filter repos", value="")
             with gr.Accordion("Advanced", open=False):
                 with gr.Row(equal_height=True):
-                    filter_str = gr.Textbox(label="Filter", info="String(s) to filter repos", value="")
                     search_str = gr.Textbox(label="Search", info="A string that will be contained in the returned repo ids", placeholder="bert", value="", lines=1)
                     author = gr.Textbox(label="Author", info="The author (user or organization)", value="", lines=1)
                     with gr.Column():
@@ -79,7 +79,7 @@ with gr.Blocks(theme="NoCrypt/miku", fill_width=True, css=CSS) as demo:
             #rec_repo_id = gr.Textbox(label="Repo ID", info="Input your favorite repo", value="")
             rec_repo_id = HuggingfaceHubSearch(label="Repo ID", placeholder="Input your favorite Repo ID", search_type=["model", "dataset", "space"],
                                                sumbit_on_select=False)
-            rec_repo_types = gr.CheckboxGroup(label="Repo type", choices=["model", "dataset", "space"], value=["model", "dataset", "space"])
+            rec_repo_types = gr.CheckboxGroup(label="Repo type", choices=["model", "dataset", "space", "collection"], value=["model", "dataset", "space", "collection"])
             with gr.Row(equal_height=True):
                 rec_sort = gr.Radio(label="Sort", choices=["last_modified", "likes", "downloads", "downloads_all_time", "trending_score"], value="likes")
                 rec_limit = gr.Number(label="Limit", value=20, step=1, minimum=1, maximum=1000)
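On the UI side this change only adds an extra checkbox choice; the selected repo types reach the callback as a plain list of strings. A minimal, hypothetical sketch of that wiring (run_search and the handler names are stand-ins, not the app's actual callbacks):

import gradio as gr

# Hypothetical stand-in for the real search backend; it only echoes the
# selected repo types to show that CheckboxGroup delivers a list[str].
def run_search(repo_types: list[str], filter_str: str) -> str:
    return f"types={repo_types}, filter={filter_str!r}"

with gr.Blocks() as demo:
    repo_types = gr.CheckboxGroup(label="Repo type",
                                  choices=["model", "dataset", "space", "collection"],
                                  value=["model", "dataset", "space"])
    filter_str = gr.Textbox(label="Filter", value="")
    result = gr.Textbox(label="Result")
    gr.Button("Search").click(run_search, inputs=[repo_types, filter_str], outputs=result)

if __name__ == "__main__":
    demo.launch()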
hfsearch.py
CHANGED
@@ -1,6 +1,7 @@
 import spaces
 import gradio as gr
-from huggingface_hub import HfApi, ModelInfo, DatasetInfo, SpaceInfo
+from huggingface_hub import HfApi, ModelInfo, DatasetInfo, SpaceInfo, Collection
+from huggingface_hub.hf_api import PaperInfo
 from typing import Union
 import gc
 import pandas as pd
@@ -16,8 +17,11 @@ def dummy_gpu():
 RESULT_ITEMS = {
     "Type": [1, "str", True],
     "ID": [2, "markdown", True, "40%"],
-    "
-    "
+    "User": [4, "str", False],
+    "Name": [5, "str", False],
+    "URL": [6, "str", False],
+    "Status": [7, "markdown", True],
+    "Gated": [8, "str", True],
     "Likes": [10, "number", True],
     "DLs": [12, "number", True],
     "AllDLs": [13, "number", False],
@@ -30,6 +34,14 @@ RESULT_ITEMS = {
     "NFAA": [40, "str", False],
 }
 
+SORT_PARAM_TO_ITEM = {
+    "last_modified": "LastMod.",
+    "likes": "Likes",
+    "downloads": "DLs",
+    "downloads_all_time": "AllDLs",
+    "trending_score": "Trending",
+}
+
 try:
     with open("tags.json", encoding="utf-8") as f:
         TAGS = json.load(f)
@@ -122,7 +134,6 @@ def get_repo_collections(repo_id: str, repo_type: str="model", limit=10):
         for c in cols:
             col = api.get_collection(collection_slug=c.slug)
             for i in col.items:
-                if i.item_type == "paper": continue
                 id = i.item_id
                 cols_dict[id] = cols_dict.get(id, 1) + 1
                 types_dict[id] = i.item_type
@@ -145,7 +156,6 @@ def get_users_collections(users: list[str], limit=10):
         for c in cols:
             col = api.get_collection(collection_slug=c.slug)
             for i in col.items:
-                if i.item_type == "paper": continue
                 id = i.item_id
                 cols_dict[id] = cols_dict.get(id, 1) + 1
                 types_dict[id] = i.item_type
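With the paper skip removed, collection items of every item_type now reach the counting dicts. A small sketch of what those items expose (the owner used to pick a collection is just an example, not taken from the app):

from huggingface_hub import HfApi

api = HfApi()
# Any public collection works here; the owner is only an example.
cols = list(api.list_collections(owner="huggingface", limit=1))
if cols:
    col = api.get_collection(collection_slug=cols[0].slug)
    types_dict = {}
    for i in col.items:
        # item_type is e.g. "model", "dataset", "space", or "paper"
        types_dict[i.item_id] = i.item_type
    print(types_dict)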
@@ -176,6 +186,42 @@ def get_ref_repos(repo_id: str):
     counts_list = list(refs.values())
     return refs_list, types_list, counts_list
 
+def get_collections_by_repo(repo_id: str, repo_type: str="model", limit=100):
+    try:
+        api = HfApi()
+        if repo_type == "dataset": item = f"datasets/{repo_id}"
+        elif repo_type == "space": item = f"spaces/{repo_id}"
+        else: item = f"models/{repo_id}"
+        cols = api.list_collections(item=item, sort="upvotes", limit=limit)
+        return [c for c in cols]
+    except Exception as e:
+        print(e)
+        raise Exception(e)
+
+def get_collections_by_users(users: list[str], limit=100):
+    try:
+        api = HfApi()
+        cols_list = []
+        for user in users[0:6]:
+            cols = api.list_collections(owner=user, sort="upvotes", limit=limit)
+            for col in cols:
+                cols_list.append(col)
+        return cols_list
+    except Exception as e:
+        print(e)
+        raise Exception(e)
+
+def get_ref_collections(repo_id: str, limit=10):
+    try:
+        repo_type = get_repo_type(repo_id)
+        likers = get_repo_likers(repo_id, repo_type)[0:10]
+        cols = get_collections_by_repo(repo_id, repo_type, limit) + get_collections_by_users(likers, limit)
+        cols = list({k.slug: k for k in cols}.values())
+        return cols
+    except Exception as e:
+        print(e)
+        raise Exception(e)
+
 def str_to_list(s: str):
     try:
         m = re.split("\n", s)
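A usage sketch of the calls behind the three new helpers, assuming an arbitrary public repo and owner (both identifiers below are illustrative, not taken from the app): HfApi.list_collections is filtered either by item= or by owner=, and get_ref_collections then merges both result sets and de-duplicates by slug.

from huggingface_hub import HfApi

api = HfApi()

# Collections that contain a given repo (mirrors get_collections_by_repo);
# the repo id is an arbitrary public example.
by_repo = list(api.list_collections(item="models/black-forest-labs/FLUX.1-dev",
                                    sort="upvotes", limit=10))

# Collections owned by a given user (mirrors get_collections_by_users).
by_owner = list(api.list_collections(owner="huggingface", sort="upvotes", limit=10))

# get_ref_collections merges both lists and de-duplicates by slug.
merged = list({c.slug: c for c in by_repo + by_owner}.values())
for c in merged[:5]:
    print(c.slug, "|", c.title, "|", c.upvotes)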
@@ -263,25 +309,47 @@ class HFSearchResult():
         if isinstance(i, ModelInfo): type = "model"
         elif isinstance(i, DatasetInfo): type = "dataset"
         elif isinstance(i, SpaceInfo): type = "space"
+        elif isinstance(i, PaperInfo): type = "paper"
+        elif isinstance(i, Collection): type = "collection"
         else: return
         self._set(type, "Type")
-        if i.
-        if i.
-        if i.
-        if
-        self._set(i.
+        if type in ["space", "model", "dataset"]:
+            self._set(i.id, "ID")
+            self._set(i.id.split("/")[0], "User")
+            self._set(i.id.split("/")[1], "Name")
+            if type == "dataset": self._set(f"https://hf.co/datasets/{i.id}", "URL")
+            elif type == "space": self._set(f"https://hf.co/spaces/{i.id}", "URL")
+            else: self._set(f"https://hf.co/{i.id}", "URL")
+            if i.likes is not None: self._set(i.likes, "Likes")
+            if i.last_modified is not None: self._set(date_to_str(i.last_modified), "LastMod.")
+            if i.trending_score is not None: self._set(int(i.trending_score), "Trending")
+            if i.tags is not None: self._set("True" if "not-for-all-audiences" in i.tags else "False", "NFAA")
+            if type in ["model", "dataset"]:
+                if i.gated is not None: self._set(i.gated if i.gated else "off", "Gated")
+                if i.downloads is not None: self._set(i.downloads, "DLs")
+                if i.downloads_all_time is not None: self._set(i.downloads_all_time, "AllDLs")
+            if type == "model":
+                if i.inference is not None: self._set(i.inference, "Status")
+                if i.library_name is not None: self._set(i.library_name, "Library")
+                if i.pipeline_tag is not None: self._set(i.pipeline_tag, "Pipeline")
+            if type == "space":
+                if i.runtime is not None:
+                    self._set(i.runtime.hardware, "Hardware")
+                    self._set(i.runtime.stage, "Stage")
+        elif type == "paper": # https://github.com/huggingface/huggingface_hub/blob/v0.27.0/src/huggingface_hub/hf_api.py#L1428
+            self._set(i.id, "ID")
+            self._set(f"https://hf.co/papers/{i.id}", "URL")
+            if i.submitted_by is not None: self._set(i.submitted_by, "User")
+            if i.title is not None: self._set(i.title, "Name")
+            if i.submitted_at is not None: self._set(date_to_str(i.submitted_at), "LastMod.")
+            if i.upvotes is not None: self._set(i.upvotes, "Likes")
+        elif type == "collection":
+            self._set(i.slug, "ID")
+            if i.owner is not None: self._set(i.owner["name"], "User")
+            if i.title is not None: self._set(i.title, "Name")
+            if i.last_updated is not None: self._set(date_to_str(i.last_updated), "LastMod.")
+            if i.upvotes is not None: self._set(i.upvotes, "Likes")
+            if i.url is not None: self._set(i.url, "URL")
         self._next()
 
     def search(self, repo_types: list, sort: str, sort_method: str, filter_str: str, search_str: str, author: str, tags: str, infer: str, gated: str, appr: list[str],
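The branches above amount to "read whichever attributes the hub object actually has". A standalone illustration of the same isinstance dispatch, reduced to a plain dict instead of the class's _set/_next bookkeeping (to_row is a made-up name for this sketch, not part of the module):

from huggingface_hub import HfApi, ModelInfo, Collection

def to_row(obj) -> dict:
    # Reduced illustration of HFSearchResult.add_item: map a hub object to a
    # few of the RESULT_ITEMS columns, depending on its concrete type.
    if isinstance(obj, ModelInfo):
        return {"Type": "model", "ID": obj.id, "Likes": obj.likes,
                "DLs": obj.downloads, "URL": f"https://hf.co/{obj.id}"}
    if isinstance(obj, Collection):
        return {"Type": "collection", "ID": obj.slug, "User": obj.owner["name"],
                "Name": obj.title, "Likes": obj.upvotes, "URL": obj.url}
    return {}

api = HfApi()
print(to_row(next(iter(api.list_models(limit=1)))))
print(to_row(next(iter(api.list_collections(limit=1)))))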
@@ -294,13 +362,22 @@ class HFSearchResult():
             mkwargs = {}
             dkwargs = {}
             skwargs = {}
+            ckwargs = {}
+            pkwargs = {}
+            if filter_str:
+                kwargs["filter"] = str_to_list(filter_str)
+                ckwargs["item"] = str_to_list(filter_str)
+                pkwargs["query"] = str_to_list(filter_str)
             if search_str: kwargs["search"] = search_str
-            if author:
+            if author:
+                kwargs["author"] = author
+                ckwargs["owner"] = author
             if tags and is_valid_arg(tags):
                 mkwargs["tags"] = str_to_list(tags)
                 dkwargs["tags"] = str_to_list(tags)
-            if limit > 0:
+            if limit > 0:
+                kwargs["limit"] = limit
+                ckwargs["limit"] = 100 if limit > 100 else limit
             if sort_method == "descending order": kwargs["direction"] = -1
             if gated == "gated":
                 mkwargs["gated"] = True
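For a concrete feel of the fan-out, here is a minimal sketch of the dicts this block builds for one example filter/author/limit (values are illustrative; the module's str_to_list splits the textbox on newlines, approximated here with str.split):

# Illustrative only: mirrors the kwargs fan-out above for one example query.
filter_str, author, limit = "sdxl", "example-user", 500

kwargs, ckwargs, pkwargs = {}, {}, {}
if filter_str:
    kwargs["filter"] = filter_str.split("\n")   # rough stand-in for str_to_list()
    ckwargs["item"] = filter_str.split("\n")
    pkwargs["query"] = filter_str.split("\n")
if author:
    kwargs["author"] = author
    ckwargs["owner"] = author
if limit > 0:
    kwargs["limit"] = limit
    ckwargs["limit"] = 100 if limit > 100 else limit  # collection listing is capped at 100 here

print(kwargs)   # {'filter': ['sdxl'], 'author': 'example-user', 'limit': 500}
print(ckwargs)  # {'item': ['sdxl'], 'owner': 'example-user', 'limit': 100}
print(pkwargs)  # {'query': ['sdxl']}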
@@ -332,11 +409,15 @@ class HFSearchResult():
                     if len(hardware) > 0 and space.runtime.stage == "RUNNING" and space.runtime.hardware not in hardware: continue
                     if len(stage) > 0 and space.runtime.stage not in stage: continue
                     self.add_item(space)
-            if
+            if "paper" in repo_types:
+                papers = api.list_papers(**pkwargs)
+                for paper in papers:
+                    self.add_item(paper)
+            if "collection" in repo_types:
+                cols = api.list_collections(**ckwargs)
+                for col in cols:
+                    self.add_item(col)
+            self.sort(sort)
         except Exception as e:
             raise Exception(f"Search error: {e}") from e
 
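api.list_papers and api.list_collections return iterables of PaperInfo and Collection objects, which is what lets add_item dispatch on isinstance. A small sketch for the paper path (the query string is an example):

from huggingface_hub import HfApi

api = HfApi()
# Mirrors api.list_papers(**pkwargs) with an example free-text query.
for paper in list(api.list_papers(query="stable diffusion"))[:3]:
    print(paper.id, "|", paper.title, "|", paper.upvotes)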
@@ -345,19 +426,20 @@ class HFSearchResult():
             self.reset()
             self.show_labels = show_labels.copy()
             api = HfApi()
+            if "model" in repo_types or "dataset" in repo_types or "space" in repo_types or "paper" in repo_types:
+                repos, types, counts = get_ref_repos(repo_id)
+                i = 0
+                for r, t in zip(repos, types):
+                    if i + 1 > limit: break
+                    i += 1
+                    if t not in repo_types: continue
+                    info = api.repo_info(repo_id=r, repo_type=t)
+                    if info: self.add_item(info)
+            if "collection" in repo_types:
+                cols = get_ref_collections(repo_id, limit)
+                for col in cols:
+                    self.add_item(col)
+            self.sort(sort)
         except Exception as e:
             raise Exception(f"Search error: {e}") from e
 
@@ -410,9 +492,9 @@ class HFSearchResult():
             return sdf
 
         def id_to_md(df: pd.DataFrame):
-            if df["Type"] == "
-            elif df["Type"] == "
-            else: return f'[{df["ID"]}](
+            if df["Type"] == "collection": return f'[{df["User"]}: {df["Name"]}]({df["URL"]})'
+            elif df["Type"] == "paper": return f'[{df["Name"]} (arxiv:{df["ID"]})]({df["URL"]})'
+            else: return f'[{df["ID"]}]({df["URL"]})'
 
         def format_md_df(df: pd.DataFrame):
             df["ID"] = df.apply(id_to_md, axis=1)
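Applied row-wise via df.apply(..., axis=1), the helper turns each record into a markdown link. A quick check against a toy DataFrame (all values below are made up):

import pandas as pd

def id_to_md(df: pd.DataFrame):
    # Same branching as above; df is a single row when used with apply(axis=1).
    if df["Type"] == "collection": return f'[{df["User"]}: {df["Name"]}]({df["URL"]})'
    elif df["Type"] == "paper": return f'[{df["Name"]} (arxiv:{df["ID"]})]({df["URL"]})'
    else: return f'[{df["ID"]}]({df["URL"]})'

rows = pd.DataFrame([
    {"Type": "model", "ID": "org/model-x", "User": "org", "Name": "model-x",
     "URL": "https://hf.co/org/model-x"},
    {"Type": "collection", "ID": "org/cool-stuff-123", "User": "org",
     "Name": "Cool stuff", "URL": "https://hf.co/collections/org/cool-stuff-123"},
])
print(rows.apply(id_to_md, axis=1).tolist())
# ['[org/model-x](https://hf.co/org/model-x)',
#  '[org: Cool stuff](https://hf.co/collections/org/cool-stuff-123)']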
@@ -460,6 +542,12 @@ class HFSearchResult():
 
     def sort(self, key="Likes"):
         if len(self.item_list) == 0: raise Exception("No item found.")
+        if key in SORT_PARAM_TO_ITEM.keys(): key = SORT_PARAM_TO_ITEM[key]
+        types = set()
+        for i in self.item_list:
+            if "Type" in i.keys(): types.add(i["Type"])
+        if "paper" in types: return
+        if key in ["DLs", "AllDLs"] and ("space" in types or "collection" in types): key = "Likes"
         if not key in self.labels.get()[0]: key = "Likes"
         self.item_list, self.item_hide_flags, self.item_info_list = zip(*sorted(zip(self.item_list, self.item_hide_flags, self.item_info_list), key=lambda x: x[0][key], reverse=True))
 
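A standalone sketch of the new key handling, separate from the class: hub sort parameters are remapped through SORT_PARAM_TO_ITEM, results containing papers are left in their original order, and download-based keys fall back to Likes when spaces or collections are mixed in (pick_sort_key is an illustrative name, not part of the module):

SORT_PARAM_TO_ITEM = {
    "last_modified": "LastMod.",
    "likes": "Likes",
    "downloads": "DLs",
    "downloads_all_time": "AllDLs",
    "trending_score": "Trending",
}

def pick_sort_key(key: str, items: list[dict]):
    # Mirrors the key selection in HFSearchResult.sort(); None means
    # "leave the current order untouched" (the paper case above).
    if key in SORT_PARAM_TO_ITEM:
        key = SORT_PARAM_TO_ITEM[key]
    types = {i["Type"] for i in items if "Type" in i}
    if "paper" in types:
        return None
    if key in ["DLs", "AllDLs"] and ("space" in types or "collection" in types):
        key = "Likes"
    return key

items = [{"Type": "model", "Likes": 5, "DLs": 100},
         {"Type": "collection", "Likes": 9}]
key = pick_sort_key("downloads", items)   # -> "Likes": collections have no download counts
if key is not None:
    items = sorted(items, key=lambda x: x[key], reverse=True)
print(key, items)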