John6666 committed on
Commit
b53c1d8
·
verified ·
1 Parent(s): 68eb6f0

Upload 2 files

Browse files
Files changed (2) hide show
  1. app.py +3 -3
  2. hfsearch.py +134 -46
app.py CHANGED
@@ -17,10 +17,10 @@ with gr.Blocks(theme="NoCrypt/miku", fill_width=True, css=CSS) as demo:
17
  with gr.Tab("Normal Search"):
18
  with gr.Group():
19
  with gr.Row(equal_height=True):
20
- repo_types = gr.CheckboxGroup(label="Repo type", choices=["model", "dataset", "space"], value=["model", "dataset", "space"])
 
21
  with gr.Accordion("Advanced", open=False):
22
  with gr.Row(equal_height=True):
23
- filter_str = gr.Textbox(label="Filter", info="String(s) to filter repos", value="")
24
  search_str = gr.Textbox(label="Search", info="A string that will be contained in the returned repo ids", placeholder="bert", value="", lines=1)
25
  author = gr.Textbox(label="Author", info="The author (user or organization)", value="", lines=1)
26
  with gr.Column():
@@ -79,7 +79,7 @@ with gr.Blocks(theme="NoCrypt/miku", fill_width=True, css=CSS) as demo:
79
  #rec_repo_id = gr.Textbox(label="Repo ID", info="Input your favorite repo", value="")
80
  rec_repo_id = HuggingfaceHubSearch(label="Repo ID", placeholder="Input your favorite Repo ID", search_type=["model", "dataset", "space"],
81
  sumbit_on_select=False)
82
- rec_repo_types = gr.CheckboxGroup(label="Repo type", choices=["model", "dataset", "space"], value=["model", "dataset", "space"])
83
  with gr.Row(equal_height=True):
84
  rec_sort = gr.Radio(label="Sort", choices=["last_modified", "likes", "downloads", "downloads_all_time", "trending_score"], value="likes")
85
  rec_limit = gr.Number(label="Limit", value=20, step=1, minimum=1, maximum=1000)
 
17
  with gr.Tab("Normal Search"):
18
  with gr.Group():
19
  with gr.Row(equal_height=True):
20
+ repo_types = gr.CheckboxGroup(label="Repo type", choices=["model", "dataset", "space", "collection"], value=["model", "dataset", "space"])
21
+ filter_str = gr.Textbox(label="Filter", info="String(s) to filter repos", value="")
22
  with gr.Accordion("Advanced", open=False):
23
  with gr.Row(equal_height=True):
 
24
  search_str = gr.Textbox(label="Search", info="A string that will be contained in the returned repo ids", placeholder="bert", value="", lines=1)
25
  author = gr.Textbox(label="Author", info="The author (user or organization)", value="", lines=1)
26
  with gr.Column():
 
79
  #rec_repo_id = gr.Textbox(label="Repo ID", info="Input your favorite repo", value="")
80
  rec_repo_id = HuggingfaceHubSearch(label="Repo ID", placeholder="Input your favorite Repo ID", search_type=["model", "dataset", "space"],
81
  sumbit_on_select=False)
82
+ rec_repo_types = gr.CheckboxGroup(label="Repo type", choices=["model", "dataset", "space", "collection"], value=["model", "dataset", "space", "collection"])
83
  with gr.Row(equal_height=True):
84
  rec_sort = gr.Radio(label="Sort", choices=["last_modified", "likes", "downloads", "downloads_all_time", "trending_score"], value="likes")
85
  rec_limit = gr.Number(label="Limit", value=20, step=1, minimum=1, maximum=1000)
hfsearch.py CHANGED
@@ -1,6 +1,7 @@
1
  import spaces
2
  import gradio as gr
3
- from huggingface_hub import HfApi, ModelInfo, DatasetInfo, SpaceInfo
 
4
  from typing import Union
5
  import gc
6
  import pandas as pd
@@ -16,8 +17,11 @@ def dummy_gpu():
16
  RESULT_ITEMS = {
17
  "Type": [1, "str", True],
18
  "ID": [2, "markdown", True, "40%"],
19
- "Status": [4, "markdown", True],
20
- "Gated": [6, "str", True],
 
 
 
21
  "Likes": [10, "number", True],
22
  "DLs": [12, "number", True],
23
  "AllDLs": [13, "number", False],
@@ -30,6 +34,14 @@ RESULT_ITEMS = {
30
  "NFAA": [40, "str", False],
31
  }
32
 
 
 
 
 
 
 
 
 
33
  try:
34
  with open("tags.json", encoding="utf-8") as f:
35
  TAGS = json.load(f)
@@ -122,7 +134,6 @@ def get_repo_collections(repo_id: str, repo_type: str="model", limit=10):
122
  for c in cols:
123
  col = api.get_collection(collection_slug=c.slug)
124
  for i in col.items:
125
- if i.item_type == "paper": continue
126
  id = i.item_id
127
  cols_dict[id] = cols_dict.get(id, 1) + 1
128
  types_dict[id] = i.item_type
@@ -145,7 +156,6 @@ def get_users_collections(users: list[str], limit=10):
145
  for c in cols:
146
  col = api.get_collection(collection_slug=c.slug)
147
  for i in col.items:
148
- if i.item_type == "paper": continue
149
  id = i.item_id
150
  cols_dict[id] = cols_dict.get(id, 1) + 1
151
  types_dict[id] = i.item_type
@@ -176,6 +186,42 @@ def get_ref_repos(repo_id: str):
176
  counts_list = list(refs.values())
177
  return refs_list, types_list, counts_list
178
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
179
  def str_to_list(s: str):
180
  try:
181
  m = re.split("\n", s)
@@ -263,25 +309,47 @@ class HFSearchResult():
263
  if isinstance(i, ModelInfo): type = "model"
264
  elif isinstance(i, DatasetInfo): type = "dataset"
265
  elif isinstance(i, SpaceInfo): type = "space"
 
 
266
  else: return
267
  self._set(type, "Type")
268
- self._set(i.id, "ID")
269
- if i.likes is not None: self._set(i.likes, "Likes")
270
- if i.last_modified is not None: self._set(date_to_str(i.last_modified), "LastMod.")
271
- if i.trending_score is not None: self._set(int(i.trending_score), "Trending")
272
- if i.tags is not None: self._set("True" if "not-for-all-audiences" in i.tags else "False", "NFAA")
273
- if type in ["model", "dataset"]:
274
- if i.gated is not None: self._set(i.gated if i.gated else "off", "Gated")
275
- if i.downloads is not None: self._set(i.downloads, "DLs")
276
- if i.downloads_all_time is not None: self._set(i.downloads_all_time, "AllDLs")
277
- if type == "model":
278
- if i.inference is not None: self._set(i.inference, "Status")
279
- if i.library_name is not None: self._set(i.library_name, "Library")
280
- if i.pipeline_tag is not None: self._set(i.pipeline_tag, "Pipeline")
281
- if type == "space":
282
- if i.runtime is not None:
283
- self._set(i.runtime.hardware, "Hardware")
284
- self._set(i.runtime.stage, "Stage")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
285
  self._next()
286
 
287
  def search(self, repo_types: list, sort: str, sort_method: str, filter_str: str, search_str: str, author: str, tags: str, infer: str, gated: str, appr: list[str],
@@ -294,13 +362,22 @@ class HFSearchResult():
294
  mkwargs = {}
295
  dkwargs = {}
296
  skwargs = {}
297
- if filter_str: kwargs["filter"] = str_to_list(filter_str)
 
 
 
 
 
298
  if search_str: kwargs["search"] = search_str
299
- if author: kwargs["author"] = author
 
 
300
  if tags and is_valid_arg(tags):
301
  mkwargs["tags"] = str_to_list(tags)
302
  dkwargs["tags"] = str_to_list(tags)
303
- if limit > 0: kwargs["limit"] = limit
 
 
304
  if sort_method == "descending order": kwargs["direction"] = -1
305
  if gated == "gated":
306
  mkwargs["gated"] = True
@@ -332,11 +409,15 @@ class HFSearchResult():
332
  if len(hardware) > 0 and space.runtime.stage == "RUNNING" and space.runtime.hardware not in hardware: continue
333
  if len(stage) > 0 and space.runtime.stage not in stage: continue
334
  self.add_item(space)
335
- if sort == "downloads" and ("space" not in repo_types): self.sort("DLs")
336
- elif sort == "downloads_all_time" and ("space" not in repo_types): self.sort("AllDLs")
337
- elif sort == "likes": self.sort("Likes")
338
- elif sort == "trending_score": self.sort("Trending")
339
- else: self.sort("LastMod.")
 
 
 
 
340
  except Exception as e:
341
  raise Exception(f"Search error: {e}") from e
342
 
@@ -345,19 +426,20 @@ class HFSearchResult():
345
  self.reset()
346
  self.show_labels = show_labels.copy()
347
  api = HfApi()
348
- repos, types, counts = get_ref_repos(repo_id)
349
- i = 0
350
- for r, t in zip(repos, types):
351
- if i + 1 > limit: break
352
- i += 1
353
- if t not in repo_types: continue
354
- info = api.repo_info(repo_id=r, repo_type=t)
355
- if info: self.add_item(info)
356
- if sort == "downloads" and ("space" not in repo_types): self.sort("DLs")
357
- elif sort == "downloads_all_time" and ("space" not in repo_types): self.sort("AllDLs")
358
- elif sort == "likes": self.sort("Likes")
359
- elif sort == "trending_score": self.sort("Trending")
360
- else: self.sort("LastMod.")
 
361
  except Exception as e:
362
  raise Exception(f"Search error: {e}") from e
363
 
@@ -410,9 +492,9 @@ class HFSearchResult():
410
  return sdf
411
 
412
  def id_to_md(df: pd.DataFrame):
413
- if df["Type"] == "dataset": return f'[{df["ID"]}](https://hf.co/datasets/{df["ID"]})'
414
- elif df["Type"] == "space": return f'[{df["ID"]}](https://hf.co/spaces/{df["ID"]})'
415
- else: return f'[{df["ID"]}](https://hf.co/{df["ID"]})'
416
 
417
  def format_md_df(df: pd.DataFrame):
418
  df["ID"] = df.apply(id_to_md, axis=1)
@@ -460,6 +542,12 @@ class HFSearchResult():
460
 
461
  def sort(self, key="Likes"):
462
  if len(self.item_list) == 0: raise Exception("No item found.")
 
 
 
 
 
 
463
  if not key in self.labels.get()[0]: key = "Likes"
464
  self.item_list, self.item_hide_flags, self.item_info_list = zip(*sorted(zip(self.item_list, self.item_hide_flags, self.item_info_list), key=lambda x: x[0][key], reverse=True))
465
 
 
1
  import spaces
2
  import gradio as gr
3
+ from huggingface_hub import HfApi, ModelInfo, DatasetInfo, SpaceInfo, Collection
4
+ from huggingface_hub.hf_api import PaperInfo
5
  from typing import Union
6
  import gc
7
  import pandas as pd
 
17
  RESULT_ITEMS = {
18
  "Type": [1, "str", True],
19
  "ID": [2, "markdown", True, "40%"],
20
+ "User": [4, "str", False],
21
+ "Name": [5, "str", False],
22
+ "URL": [6, "str", False],
23
+ "Status": [7, "markdown", True],
24
+ "Gated": [8, "str", True],
25
  "Likes": [10, "number", True],
26
  "DLs": [12, "number", True],
27
  "AllDLs": [13, "number", False],
 
34
  "NFAA": [40, "str", False],
35
  }
36
 
37
+ SORT_PARAM_TO_ITEM = {
38
+ "last_modified": "LastMod.",
39
+ "likes": "Likes",
40
+ "downloads": "DLs",
41
+ "downloads_all_time": "AllDLs",
42
+ "trending_score": "Trending",
43
+ }
44
+
45
  try:
46
  with open("tags.json", encoding="utf-8") as f:
47
  TAGS = json.load(f)
 
134
  for c in cols:
135
  col = api.get_collection(collection_slug=c.slug)
136
  for i in col.items:
 
137
  id = i.item_id
138
  cols_dict[id] = cols_dict.get(id, 1) + 1
139
  types_dict[id] = i.item_type
 
156
  for c in cols:
157
  col = api.get_collection(collection_slug=c.slug)
158
  for i in col.items:
 
159
  id = i.item_id
160
  cols_dict[id] = cols_dict.get(id, 1) + 1
161
  types_dict[id] = i.item_type
 
186
  counts_list = list(refs.values())
187
  return refs_list, types_list, counts_list
188
 
189
def get_collections_by_repo(repo_id: str, repo_type: str="model", limit=100):
    """Return the Hub collections that reference the given repo.

    Args:
        repo_id: ID of the repo to look up.
        repo_type: "dataset" or "space"; any other value is treated as a model.
        limit: Value forwarded as ``limit`` to ``HfApi.list_collections``.

    Returns:
        A list of the collection objects yielded by ``HfApi.list_collections``,
        requested with ``sort="upvotes"``.

    Raises:
        Exception: Wraps any error raised during the lookup; the original
            error is printed and kept as ``__cause__``.
    """
    try:
        api = HfApi()
        # The item filter is built as "<plural repo type>/<repo_id>";
        # unknown repo types fall back to the "models" prefix.
        prefixes = {"dataset": "datasets", "space": "spaces"}
        prefix = prefixes.get(repo_type, "models")
        item = f"{prefix}/{repo_id}"
        cols = api.list_collections(item=item, sort="upvotes", limit=limit)
        # Materialize the iterable so callers can concatenate and reuse it.
        return list(cols)
    except Exception as e:
        print(e)
        # Chain explicitly so the original traceback stays attached.
        raise Exception(e) from e
200
+
201
def get_collections_by_users(users: list[str], limit=100, max_users=6):
    """Return the Hub collections owned by the first few of the given users.

    Args:
        users: User or organization names to query, in priority order.
        limit: Value forwarded as ``limit`` to each ``HfApi.list_collections`` call.
        max_users: How many leading entries of ``users`` are queried. Defaults
            to 6 (the previously hard-coded cap); ``None`` queries every user.

    Returns:
        A flat list of collection objects, grouped in the order of ``users``.
        Duplicates are not removed here.

    Raises:
        Exception: Wraps any error raised during the lookup; the original
            error is printed and kept as ``__cause__``.
    """
    try:
        api = HfApi()
        cols_list = []
        # One list_collections call is made per user, so the cap bounds the work.
        for user in users[:max_users]:
            cols_list.extend(api.list_collections(owner=user, sort="upvotes", limit=limit))
        return cols_list
    except Exception as e:
        print(e)
        # Chain explicitly so the original traceback stays attached.
        raise Exception(e) from e
213
+
214
def get_ref_collections(repo_id: str, limit=10):
    """Collect collections related to a repo, de-duplicated by slug.

    Combines the collections that contain the repo itself with the
    collections owned by users who liked the repo.

    Args:
        repo_id: ID of the repo used as the starting point.
        limit: Forwarded to both collection lookups as their ``limit``.

    Returns:
        A list of collection objects with unique ``slug`` values.

    Raises:
        Exception: Wraps any error raised by the helpers; the original error
            is printed and kept as ``__cause__``.
    """
    try:
        repo_type = get_repo_type(repo_id)
        # Only the first 10 likers are considered.
        likers = get_repo_likers(repo_id, repo_type)[0:10]
        cols = get_collections_by_repo(repo_id, repo_type, limit) + get_collections_by_users(likers, limit)
        # Keying by slug drops repeats: a slug keeps the position of its first
        # occurrence and the object of its last occurrence.
        unique_cols = {k.slug: k for k in cols}
        return list(unique_cols.values())
    except Exception as e:
        print(e)
        # Chain explicitly so the original traceback stays attached.
        raise Exception(e) from e
224
+
225
  def str_to_list(s: str):
226
  try:
227
  m = re.split("\n", s)
 
309
  if isinstance(i, ModelInfo): type = "model"
310
  elif isinstance(i, DatasetInfo): type = "dataset"
311
  elif isinstance(i, SpaceInfo): type = "space"
312
+ elif isinstance(i, PaperInfo): type = "paper"
313
+ elif isinstance(i, Collection): type = "collection"
314
  else: return
315
  self._set(type, "Type")
316
+ if type in ["space", "model", "dataset"]:
317
+ self._set(i.id, "ID")
318
+ self._set(i.id.split("/")[0], "User")
319
+ self._set(i.id.split("/")[1], "Name")
320
+ if type == "dataset": self._set(f"https://hf.co/datasets/{i.id}", "URL")
321
+ elif type == "space": self._set(f"https://hf.co/spaces/{i.id}", "URL")
322
+ else: self._set(f"https://hf.co/{i.id}", "URL")
323
+ if i.likes is not None: self._set(i.likes, "Likes")
324
+ if i.last_modified is not None: self._set(date_to_str(i.last_modified), "LastMod.")
325
+ if i.trending_score is not None: self._set(int(i.trending_score), "Trending")
326
+ if i.tags is not None: self._set("True" if "not-for-all-audiences" in i.tags else "False", "NFAA")
327
+ if type in ["model", "dataset"]:
328
+ if i.gated is not None: self._set(i.gated if i.gated else "off", "Gated")
329
+ if i.downloads is not None: self._set(i.downloads, "DLs")
330
+ if i.downloads_all_time is not None: self._set(i.downloads_all_time, "AllDLs")
331
+ if type == "model":
332
+ if i.inference is not None: self._set(i.inference, "Status")
333
+ if i.library_name is not None: self._set(i.library_name, "Library")
334
+ if i.pipeline_tag is not None: self._set(i.pipeline_tag, "Pipeline")
335
+ if type == "space":
336
+ if i.runtime is not None:
337
+ self._set(i.runtime.hardware, "Hardware")
338
+ self._set(i.runtime.stage, "Stage")
339
+ elif type == "paper": # https://github.com/huggingface/huggingface_hub/blob/v0.27.0/src/huggingface_hub/hf_api.py#L1428
340
+ self._set(i.id, "ID")
341
+ self._set(f"https://hf.co/papers/{i.id}", "URL")
342
+ if i.submitted_by is not None: self._set(i.submitted_by, "User")
343
+ if i.title is not None: self._set(i.title, "Name")
344
+ if i.submitted_at is not None: self._set(date_to_str(i.submitted_at), "LastMod.")
345
+ if i.upvotes is not None: self._set(i.upvotes, "Likes")
346
+ elif type == "collection":
347
+ self._set(i.slug, "ID")
348
+ if i.owner is not None: self._set(i.owner["name"], "User")
349
+ if i.title is not None: self._set(i.title, "Name")
350
+ if i.last_updated is not None: self._set(date_to_str(i.last_updated), "LastMod.")
351
+ if i.upvotes is not None: self._set(i.upvotes, "Likes")
352
+ if i.url is not None: self._set(i.url, "URL")
353
  self._next()
354
 
355
  def search(self, repo_types: list, sort: str, sort_method: str, filter_str: str, search_str: str, author: str, tags: str, infer: str, gated: str, appr: list[str],
 
362
  mkwargs = {}
363
  dkwargs = {}
364
  skwargs = {}
365
+ ckwargs = {}
366
+ pkwargs = {}
367
+ if filter_str:
368
+ kwargs["filter"] = str_to_list(filter_str)
369
+ ckwargs["item"] = str_to_list(filter_str)
370
+ pkwargs["query"] = str_to_list(filter_str)
371
  if search_str: kwargs["search"] = search_str
372
+ if author:
373
+ kwargs["author"] = author
374
+ ckwargs["owner"] = author
375
  if tags and is_valid_arg(tags):
376
  mkwargs["tags"] = str_to_list(tags)
377
  dkwargs["tags"] = str_to_list(tags)
378
+ if limit > 0:
379
+ kwargs["limit"] = limit
380
+ ckwargs["limit"] = 100 if limit > 100 else limit
381
  if sort_method == "descending order": kwargs["direction"] = -1
382
  if gated == "gated":
383
  mkwargs["gated"] = True
 
409
  if len(hardware) > 0 and space.runtime.stage == "RUNNING" and space.runtime.hardware not in hardware: continue
410
  if len(stage) > 0 and space.runtime.stage not in stage: continue
411
  self.add_item(space)
412
+ if "paper" in repo_types:
413
+ papers = api.list_papers(**pkwargs)
414
+ for paper in papers:
415
+ self.add_item(paper)
416
+ if "collection" in repo_types:
417
+ cols = api.list_collections(**ckwargs)
418
+ for col in cols:
419
+ self.add_item(col)
420
+ self.sort(sort)
421
  except Exception as e:
422
  raise Exception(f"Search error: {e}") from e
423
 
 
426
  self.reset()
427
  self.show_labels = show_labels.copy()
428
  api = HfApi()
429
+ if "model" in repo_types or "dataset" in repo_types or "space" in repo_types or "paper" in repo_types:
430
+ repos, types, counts = get_ref_repos(repo_id)
431
+ i = 0
432
+ for r, t in zip(repos, types):
433
+ if i + 1 > limit: break
434
+ i += 1
435
+ if t not in repo_types: continue
436
+ info = api.repo_info(repo_id=r, repo_type=t)
437
+ if info: self.add_item(info)
438
+ if "collection" in repo_types:
439
+ cols = get_ref_collections(repo_id, limit)
440
+ for col in cols:
441
+ self.add_item(col)
442
+ self.sort(sort)
443
  except Exception as e:
444
  raise Exception(f"Search error: {e}") from e
445
 
 
492
  return sdf
493
 
494
  def id_to_md(df: pd.DataFrame):
495
+ if df["Type"] == "collection": return f'[{df["User"]}: {df["Name"]}]({df["URL"]})'
496
+ elif df["Type"] == "paper": return f'[{df["Name"]} (arxiv:{df["ID"]})]({df["URL"]})'
497
+ else: return f'[{df["ID"]}]({df["URL"]})'
498
 
499
  def format_md_df(df: pd.DataFrame):
500
  df["ID"] = df.apply(id_to_md, axis=1)
 
542
 
543
  def sort(self, key="Likes"):
544
  if len(self.item_list) == 0: raise Exception("No item found.")
545
+ if key in SORT_PARAM_TO_ITEM.keys(): key = SORT_PARAM_TO_ITEM[key]
546
+ types = set()
547
+ for i in self.item_list:
548
+ if "Type" in i.keys(): types.add(i["Type"])
549
+ if "paper" in types: return
550
+ if key in ["DLs", "AllDLs"] and ("space" in types or "collection" in types): key = "Likes"
551
  if not key in self.labels.get()[0]: key = "Likes"
552
  self.item_list, self.item_hide_flags, self.item_info_list = zip(*sorted(zip(self.item_list, self.item_hide_flags, self.item_info_list), key=lambda x: x[0][key], reverse=True))
553