John6666 committed on
Commit 68eb6f0 · verified · 1 Parent(s): 6383f6b

Upload 8 files

Files changed (8)
  1. README.md +13 -12
  2. app.py +116 -0
  3. hfconstants.py +7 -0
  4. hfsearch.py +514 -0
  5. pre-requirements.txt +1 -0
  6. requirements.txt +2 -0
  7. subtags.json +0 -0
  8. tags.json +0 -0
README.md CHANGED
@@ -1,12 +1,13 @@
- ---
- title: Hfsearch
- emoji: 🏃
- colorFrom: green
- colorTo: purple
- sdk: gradio
- sdk_version: 5.9.1
- app_file: app.py
- pinned: false
- ---
-
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
+ ---
+ title: Hugging Face🤗 Searcher
+ emoji: 🤗🔍
+ colorFrom: indigo
+ colorTo: purple
+ sdk: gradio
+ sdk_version: 4.44.0
+ app_file: app.py
+ pinned: false
+ license: mit
+ ---
+
+ Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
app.py ADDED
@@ -0,0 +1,116 @@
1
+ import spaces
2
+ import gradio as gr
3
+ from hfsearch import (HFSearchResult, search, update_filter, update_df, get_labels, get_valid_labels,
4
+ get_tags, get_subtag_categories, update_subtag_items, update_tags, update_subtags,
5
+ search_ref_repos, DS_SIZE_CATEGORIES, SPACE_HARDWARES, SPACE_STAGES)
6
+ from gradio_huggingfacehub_search import HuggingfaceHubSearch
7
+
8
+ CSS = """
9
+ .title { align-items: center; text-align: center; }
10
+ .info { align-items: center; text-align: center; }
11
+ """
12
+
13
+ with gr.Blocks(theme="NoCrypt/miku", fill_width=True, css=CSS) as demo:
14
+ gr.Markdown("# Search Hugging Face🤗", elem_classes="title")
15
+ with gr.Column():
16
+ search_result = gr.State(value=HFSearchResult())
17
+ with gr.Tab("Normal Search"):
18
+ with gr.Group():
19
+ with gr.Row(equal_height=True):
20
+ repo_types = gr.CheckboxGroup(label="Repo type", choices=["model", "dataset", "space"], value=["model", "dataset", "space"])
21
+ with gr.Accordion("Advanced", open=False):
22
+ with gr.Row(equal_height=True):
23
+ filter_str = gr.Textbox(label="Filter", info="String(s) to filter repos", value="")
24
+ search_str = gr.Textbox(label="Search", info="A string that will be contained in the returned repo ids", placeholder="bert", value="", lines=1)
25
+ author = gr.Textbox(label="Author", info="The author (user or organization)", value="", lines=1)
26
+ with gr.Column():
27
+ tags = gr.Textbox(label="Tags", info="Tag(s) to filter repos", value="")
28
+ with gr.Accordion("Tag input assistance", open=False):
29
+ with gr.Row(equal_height=True):
30
+ tag_item = gr.Dropdown(label="Item", choices=get_tags(), value=get_tags()[0], allow_custom_value=True, scale=4)
31
+ tag_btn = gr.Button("Add", scale=1)
32
+ with gr.Row(equal_height=True):
33
+ subtag_cat = gr.Dropdown(label="Category", choices=get_subtag_categories(), value=get_subtag_categories()[0], scale=2)
34
+ subtag_item = gr.Dropdown(label="Item", choices=[""], value="", allow_custom_value=True, scale=2)
35
+ subtag_btn = gr.Button("Add", scale=1)
36
+ with gr.Column():
37
+ gated_status = gr.Radio(label="Gated status", choices=["gated", "non-gated", "all"], value="all")
38
+ appr_status = gr.CheckboxGroup(label="Approval method", choices=["auto", "manual"], value=["auto", "manual"])
39
+ with gr.Tab("for Models"):
40
+ with gr.Column():
41
+ infer_status = gr.Radio(label="Inference status", choices=["warm", "cold", "frozen", "all"], value="all")
42
+ gr.Markdown("[About the Inference API status (Warm, Cold, Frozen)](https://huggingface.co/docs/api-inference/supported-models)", elem_classes="info")
43
+ # with gr.Row(equal_height=True):
44
+ # model_task = gr.Textbox(label="Task", info="String(s) of tasks models were designed for", placeholder="fill-mask", value="")
45
+ # trained_dataset = gr.Textbox(label="Trained dataset", info="Trained dataset for a model", value="")
46
+ with gr.Tab("for Datasets"):
47
+ size_categories = gr.CheckboxGroup(label="Size categories", info="The size of the dataset", choices=DS_SIZE_CATEGORIES, value=[])
48
+ # task_categories = gr.Textbox(label="Task categories", info="Identify datasets by the designed task", value="")
49
+ # task_ids = gr.Textbox(label="Task IDs", info="Identify datasets by the specific task", value="")
50
+ # language_creators = gr.Textbox(label="Language creators", info="Identify datasets with how the data was curated", value="")
51
+ # language = gr.Textbox(label="Language", info="String(s) representing two-character language to filter datasets by", value="")
52
+ # multilinguality = gr.Textbox(label="Multilinguality", info="String(s) representing a filter for datasets that contain multiple languages", value="")
53
+ with gr.Tab("for Spaces"):
54
+ with gr.Row(equal_height=True):
55
+ hardware = gr.CheckboxGroup(label="Specify hardware", choices=SPACE_HARDWARES, value=[])
56
+ stage = gr.CheckboxGroup(label="Specify stage", choices=SPACE_STAGES, value=[])
57
+ with gr.Row(equal_height=True):
58
+ sort = gr.Radio(label="Sort", choices=["last_modified", "likes", "downloads", "trending_score"], value="likes")
59
+ sort_method = gr.Radio(label="Sort method", choices=["ascending order", "descending order"], value="ascending order")
60
+ limit = gr.Number(label="Limit", info="If 0, fetches all models", value=1000, step=1, minimum=0, maximum=10000000)
61
+ fetch_detail = gr.CheckboxGroup(label="Fetch detail", choices=["Space Runtime"], value=["Space Runtime"])
62
+ with gr.Row(equal_height=True):
63
+ show_labels = gr.CheckboxGroup(label="Show items", choices=get_labels(), value=get_valid_labels())
64
+ run_button = gr.Button("Search", variant="primary")
65
+ with gr.Tab("Find Serverless Inference API enabled models"):
66
+ with gr.Group():
67
+ with gr.Row(equal_height=True):
68
+ infer_repo_types = gr.CheckboxGroup(label="Repo type", choices=["model", "dataset", "space"], value=["model"], visible=False)
69
+ with gr.Column():
70
+ infer_infer_status = gr.Radio(label="Inference status", choices=["warm", "cold", "frozen", "all"], value="warm")
71
+ gr.Markdown("[About the Inference API status (Warm, Cold, Frozen)](https://huggingface.co/docs/api-inference/supported-models)", elem_classes="info")
72
+ with gr.Column():
73
+ infer_gated_status = gr.Radio(label="Gated status", choices=["gated", "non-gated", "all"], value="all")
74
+ infer_appr_status = gr.CheckboxGroup(label="Approval method", choices=["auto", "manual"], value=["auto", "manual"])
75
+ infer_run_button = gr.Button("Search", variant="primary")
76
+ with gr.Tab("Find recommended repos"):
77
+ with gr.Group():
78
+ with gr.Row(equal_height=True):
79
+ #rec_repo_id = gr.Textbox(label="Repo ID", info="Input your favorite repo", value="")
80
+ rec_repo_id = HuggingfaceHubSearch(label="Repo ID", placeholder="Input your favorite Repo ID", search_type=["model", "dataset", "space"],
81
+ sumbit_on_select=False)
82
+ rec_repo_types = gr.CheckboxGroup(label="Repo type", choices=["model", "dataset", "space"], value=["model", "dataset", "space"])
83
+ with gr.Row(equal_height=True):
84
+ rec_sort = gr.Radio(label="Sort", choices=["last_modified", "likes", "downloads", "downloads_all_time", "trending_score"], value="likes")
85
+ rec_limit = gr.Number(label="Limit", value=20, step=1, minimum=1, maximum=1000)
86
+ with gr.Accordion("Advanced", open=False):
87
+ with gr.Row(equal_height=True):
88
+ rec_show_labels = gr.CheckboxGroup(label="Show items", choices=get_labels(), value=get_valid_labels())
89
+ rec_run_button = gr.Button("Search", variant="primary")
90
+ with gr.Group():
91
+ with gr.Accordion("Filter", open=False):
92
+ hide_labels = gr.CheckboxGroup(label="Hide items", choices=[], value=[], visible=False)
93
+ with gr.Row(equal_height=True):
94
+ filter_item1 = gr.Dropdown(label="Filter item", choices=[""], value="", visible=False)
95
+ filter1 = gr.Dropdown(label="Filter", choices=[""], value="", allow_custom_value=True, visible=False)
96
+ filter_btn = gr.Button("Apply filter", variant="secondary", visible=False)
97
+ result_df = gr.DataFrame(label="Results", type="pandas", value=None, interactive=False)
98
+
99
+ run_button.click(search, [repo_types, sort, sort_method, filter_str, search_str, author, tags, infer_status, gated_status, appr_status,
100
+ size_categories, limit, hardware, stage, fetch_detail, show_labels, search_result],
101
+ [result_df, hide_labels, search_result])\
102
+ .success(update_filter, [filter_item1, search_result], [filter_item1, filter1, filter_btn, search_result], queue=False)
103
+ infer_run_button.click(search, [infer_repo_types, sort, sort_method, filter_str, search_str, author, tags, infer_infer_status, infer_gated_status, infer_appr_status,
104
+ size_categories, limit, hardware, stage, fetch_detail, show_labels, search_result],
105
+ [result_df, hide_labels, search_result])\
106
+ .success(update_filter, [filter_item1, search_result], [filter_item1, filter1, filter_btn, search_result], queue=False)
107
+ gr.on(triggers=[hide_labels.change, filter_btn.click], fn=update_df, inputs=[hide_labels, filter_item1, filter1, search_result],
108
+ outputs=[result_df, search_result], trigger_mode="once", queue=False, show_api=False)
109
+ filter_item1.change(update_filter, [filter_item1, search_result], [filter_item1, filter1, filter_btn, search_result], queue=False, show_api=False)
110
+ subtag_cat.change(update_subtag_items, [subtag_cat], [subtag_item], queue=False, show_api=False)
111
+ subtag_btn.click(update_subtags, [tags, subtag_cat, subtag_item], [tags], queue=False, show_api=False)
112
+ tag_btn.click(update_tags, [tags, tag_item], [tags], queue=False, show_api=False)
113
+ gr.on(triggers=[rec_run_button.click, rec_repo_id.submit], fn=search_ref_repos, inputs=[rec_repo_id, rec_repo_types, rec_sort, rec_show_labels, rec_limit, search_result],
114
+ outputs=[result_df, hide_labels, search_result])
115
+
116
+ demo.queue().launch()
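
The Blocks UI above only wires Gradio events to the helpers in hfsearch.py, so the same search pipeline can be exercised headlessly. A minimal sketch, assuming the hfsearch.py added in this commit is on the import path and that the spaces, gradio, and huggingface_hub packages are installed; the query values below are illustrative and not part of the commit:

import hfsearch  # the module added in this commit

r = hfsearch.HFSearchResult()
# Same parameters as the run_button.click() wiring above, but with a small
# limit so the anonymous Hub listing stays fast.
r.search(
    repo_types=["model"], sort="likes", sort_method="descending order",
    filter_str="", search_str="bert", author="", tags="",
    infer="all", gated="all", appr=["auto", "manual"],
    size_categories=[], limit=50, hardware=[], stage=[],
    fetch_detail=[], show_labels=["Type", "ID", "Likes", "DLs", "LastMod."],
)
df_styler, labels, label_types = r.get()  # styled pandas DataFrame plus the visible columns
print(labels, label_types)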
hfconstants.py ADDED
@@ -0,0 +1,7 @@
+
+ DS_SIZE_CATEGORIES = ["n<1K", "1K<n<10K", "10K<n<100K", "100K<n<1M", "1M<n<10M", "10M<n<100M",
+                       "100M<n<1B", "1B<n<10B", "10B<n<100B", "100B<n<1T", "n>1T"]
+
+ SPACE_HARDWARES = ["cpu-basic", "zero-a10g", "cpu-upgrade", "t4-small", "l4x1", "a10g-large", "l40sx1", "a10g-small", "t4-medium", "cpu-xl", "a100-large"]
+
+ SPACE_STAGES = ["RUNNING", "SLEEPING", "RUNTIME_ERROR", "PAUSED", "BUILD_ERROR", "CONFIG_ERROR", "BUILDING", "APP_STARTING", "RUNNING_APP_STARTING"]
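
These lists are consumed in two places: app.py uses them as CheckboxGroup choices, and hfsearch.py forwards the selected values to the Hub API (dataset size categories) or compares them against each Space's runtime (hardware and stage). A short sketch of the dataset-size path, assuming anonymous Hub access; the category index and limit are only illustrative:

from huggingface_hub import HfApi
from hfconstants import DS_SIZE_CATEGORIES

api = HfApi()
# hfsearch.py passes the checked categories straight through as
# list_datasets(size_categories=...), e.g. datasets with 1M-10M rows:
for ds in api.list_datasets(size_categories=[DS_SIZE_CATEGORIES[4]],
                            sort="likes", direction=-1, limit=5):
    print(ds.id, ds.likes)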
hfsearch.py ADDED
@@ -0,0 +1,514 @@
1
+ import spaces
2
+ import gradio as gr
3
+ from huggingface_hub import HfApi, ModelInfo, DatasetInfo, SpaceInfo
4
+ from typing import Union
5
+ import gc
6
+ import pandas as pd
7
+ import datetime
8
+ import json
9
+ import re
10
+ from hfconstants import DS_SIZE_CATEGORIES, SPACE_HARDWARES, SPACE_STAGES
11
+
12
+ @spaces.GPU
13
+ def dummy_gpu():
14
+ pass
15
+
16
+ RESULT_ITEMS = {
17
+ "Type": [1, "str", True],
18
+ "ID": [2, "markdown", True, "40%"],
19
+ "Status": [4, "markdown", True],
20
+ "Gated": [6, "str", True],
21
+ "Likes": [10, "number", True],
22
+ "DLs": [12, "number", True],
23
+ "AllDLs": [13, "number", False],
24
+ "Trending": [16, "number", True],
25
+ "LastMod.": [17, "str", True],
26
+ "Library": [20, "markdown", False],
27
+ "Pipeline": [21, "markdown", True],
28
+ "Hardware": [25, "str", False],
29
+ "Stage": [26, "str", False],
30
+ "NFAA": [40, "str", False],
31
+ }
32
+
33
+ try:
34
+ with open("tags.json", encoding="utf-8") as f:
35
+ TAGS = json.load(f)
36
+ with open("subtags.json", encoding="utf-8") as f:
37
+ SUBTAGS = json.load(f)
38
+ except Exception as e:
39
+ TAGS = []
40
+ SUBTAGS = {}
41
+ print(e)
42
+
43
+ def get_tags():
44
+ return TAGS[0:1000]
45
+
46
+ def get_subtag_categories():
47
+ return list(SUBTAGS.keys())
48
+
49
+ def update_subtag_items(category: str):
50
+ choices=[""] + list(SUBTAGS.get(category, []))
51
+ return gr.update(choices=choices, value=choices[0])
52
+
53
+ def update_subtags(tags: str, category: str, item: str):
54
+ addtag = f"{category}:{item}" if item else ""
55
+ newtags = f"{tags}\n{addtag}" if tags else addtag
56
+ return newtags
57
+
58
+ def update_tags(tags: str, item: str):
59
+ newtags = f"{tags}\n{item}" if tags else item
60
+ return newtags
61
+
62
+ def get_repo_type(repo_id: str):
63
+ try:
64
+ api = HfApi()
65
+ if api.repo_exists(repo_id=repo_id, repo_type="dataset"): return "dataset"
66
+ elif api.repo_exists(repo_id=repo_id, repo_type="space"): return "space"
67
+ elif api.repo_exists(repo_id=repo_id): return "model"
68
+ else: return None
69
+ except Exception as e:
70
+ print(e)
71
+ raise Exception(f"Repo not found: {repo_id} {e}")
72
+
73
+ def sort_dict(d: dict):
74
+ return dict(sorted(d.items(), key=lambda x: x[1], reverse=True))
75
+
76
+ def get_repo_likers(repo_id: str, repo_type: str="model"):
77
+ try:
78
+ api = HfApi()
79
+ user_list = []
80
+ users = api.list_repo_likers(repo_id=repo_id, repo_type=repo_type)
81
+ for user in users:
82
+ user_list.append(user.username)
83
+ return user_list
84
+ except Exception as e:
85
+ print(e)
86
+ raise Exception(e)
87
+
88
+ def get_liked_repos(users: list[str]):
89
+ try:
90
+ api = HfApi()
91
+ likes_dict = {}
92
+ types_dict = {}
93
+ for user in users:
94
+ likes = api.list_liked_repos(user=user)
95
+ for id in likes.models:
96
+ likes_dict[id] = likes_dict.get(id, 1) + 1
97
+ types_dict[id] = "model"
98
+ for id in likes.datasets:
99
+ likes_dict[id] = likes_dict.get(id, 1) + 1
100
+ types_dict[id] = "dataset"
101
+ for id in likes.spaces:
102
+ likes_dict[id] = likes_dict.get(id, 1) + 1
103
+ types_dict[id] = "space"
104
+ likes_dict = sort_dict(likes_dict)
105
+ likes_list = list(likes_dict.keys())
106
+ types_list = [types_dict[x] for x in likes_list]
107
+ counts_list = list(likes_dict.values())
108
+ return likes_list, types_list, counts_list
109
+ except Exception as e:
110
+ print(e)
111
+ raise Exception(e)
112
+
113
+ def get_repo_collections(repo_id: str, repo_type: str="model", limit=10):
114
+ try:
115
+ api = HfApi()
116
+ if repo_type == "dataset": item = f"datasets/{repo_id}"
117
+ elif repo_type == "space": item = f"spaces/{repo_id}"
118
+ else: item = f"models/{repo_id}"
119
+ cols_dict = {}
120
+ types_dict = {}
121
+ cols = api.list_collections(item=item, sort="upvotes", limit=limit)
122
+ for c in cols:
123
+ col = api.get_collection(collection_slug=c.slug)
124
+ for i in col.items:
125
+ if i.item_type == "paper": continue
126
+ id = i.item_id
127
+ cols_dict[id] = cols_dict.get(id, 1) + 1
128
+ types_dict[id] = i.item_type
129
+ cols_dict = sort_dict(cols_dict)
130
+ cols_list = list(cols_dict.keys())
131
+ types_list = [types_dict[x] for x in cols_list]
132
+ counts_list = list(cols_dict.values())
133
+ return cols_list, types_list, counts_list
134
+ except Exception as e:
135
+ print(e)
136
+ raise Exception(e)
137
+
138
+ def get_users_collections(users: list[str], limit=10):
139
+ try:
140
+ api = HfApi()
141
+ cols_dict = {}
142
+ types_dict = {}
143
+ for user in users[0:6]:
144
+ cols = api.list_collections(owner=user, sort="upvotes", limit=limit)
145
+ for c in cols:
146
+ col = api.get_collection(collection_slug=c.slug)
147
+ for i in col.items:
148
+ if i.item_type == "paper": continue
149
+ id = i.item_id
150
+ cols_dict[id] = cols_dict.get(id, 1) + 1
151
+ types_dict[id] = i.item_type
152
+ cols_dict = sort_dict(cols_dict)
153
+ cols_list = list(cols_dict.keys())
154
+ types_list = [types_dict[x] for x in cols_list]
155
+ counts_list = list(cols_dict.values())
156
+ return cols_list, types_list, counts_list
157
+ except Exception as e:
158
+ print(e)
159
+ raise Exception(e)
160
+
161
+ def get_ref_repos(repo_id: str):
162
+ refs = {}
163
+ types = {}
164
+ repo_type = get_repo_type(repo_id)
165
+ likers = get_repo_likers(repo_id, repo_type)[0:10]
166
+ for i, t, c in zip(*get_liked_repos(likers)):
167
+ refs[i] = refs.get(i, 0) + c * 2
168
+ types[i] = t
169
+ for i, t, c in zip(*get_repo_collections(repo_id, repo_type)):
170
+ refs[i] = refs.get(i, 0) + c * 5
171
+ types[i] = t
172
+ refs = sort_dict(refs)
173
+ if repo_id in refs.keys(): refs.pop(repo_id)
174
+ refs_list = list(refs.keys())
175
+ types_list = [types[x] for x in refs_list]
176
+ counts_list = list(refs.values())
177
+ return refs_list, types_list, counts_list
178
+
179
+ def str_to_list(s: str):
180
+ try:
181
+ m = re.split("\n", s)
182
+ return [s.strip() for s in list(m)]
183
+ except Exception:
184
+ return []
185
+
186
+ def is_valid_arg(s: str):
187
+ return len(str_to_list(s)) > 0
188
+
189
+ def get_labels():
190
+ return list(RESULT_ITEMS.keys())
191
+
192
+ def get_valid_labels():
193
+ return [k for k in list(RESULT_ITEMS.keys()) if RESULT_ITEMS[k][2]]
194
+
195
+ def date_to_str(dt: datetime.datetime):
196
+ return dt.strftime('%Y-%m-%d %H:%M')
197
+
198
+ class Labels():
199
+ VALID_DTYPE = ["str", "number", "bool", "date", "markdown"]
200
+
201
+ def __init__(self):
202
+ self.types = {}
203
+ self.orders = {}
204
+ self.widths = {}
205
+
206
+ def set(self, label: str):
207
+ if not label in RESULT_ITEMS.keys(): raise Exception(f"Invalid item: {label}")
208
+ item = RESULT_ITEMS.get(label)
209
+ if item[1] not in self.VALID_DTYPE: raise Exception(f"Invalid data type: {item[1]}")
210
+ self.types[label] = item[1]
211
+ self.orders[label] = item[0]
212
+ if len(item) > 3: self.widths[label] = item[3]
213
+ else: self.widths[label] = "10%"
214
+
215
+ def get(self):
216
+ labels = list(self.types.keys())
217
+ labels.sort(key=lambda x: self.orders[x])
218
+ label_types = [self.types[s] for s in labels]
219
+ return labels, label_types
220
+
221
+ def get_widths(self):
222
+ labels = list(self.types.keys())
223
+ label_widths = [self.widths[s] for s in labels]
224
+ return label_widths
225
+
226
+ def get_null_value(self, type: str):
227
+ if type == "bool": return False
228
+ elif type == "number" or type == "date": return 0
229
+ else: return "None"
230
+
231
+ # https://huggingface.co/docs/huggingface_hub/package_reference/hf_api
232
+ # https://huggingface.co/docs/huggingface_hub/package_reference/hf_api#huggingface_hub.ModelInfo
233
+ class HFSearchResult():
234
+ def __init__(self):
235
+ self.labels = Labels()
236
+ self.current_item = {}
237
+ self.current_item_info = None
238
+ self.item_list = []
239
+ self.item_info_list = []
240
+ self.item_hide_flags = []
241
+ self.hide_labels = []
242
+ self.show_labels = []
243
+ self.filter_items = None
244
+ self.filters = None
245
+ gc.collect()
246
+
247
+ def reset(self):
248
+ self.__init__()
249
+
250
+ def _set(self, data, label: str):
251
+ self.labels.set(label)
252
+ self.current_item[label] = data
253
+
254
+ def _next(self):
255
+ self.item_list.append(self.current_item.copy())
256
+ self.current_item = {}
257
+ self.item_info_list.append(self.current_item_info)
258
+ self.current_item_info = None
259
+ self.item_hide_flags.append(False)
260
+
261
+ def add_item(self, i: Union[ModelInfo, DatasetInfo, SpaceInfo]):
262
+ self.current_item_info = i
263
+ if isinstance(i, ModelInfo): type = "model"
264
+ elif isinstance(i, DatasetInfo): type = "dataset"
265
+ elif isinstance(i, SpaceInfo): type = "space"
266
+ else: return
267
+ self._set(type, "Type")
268
+ self._set(i.id, "ID")
269
+ if i.likes is not None: self._set(i.likes, "Likes")
270
+ if i.last_modified is not None: self._set(date_to_str(i.last_modified), "LastMod.")
271
+ if i.trending_score is not None: self._set(int(i.trending_score), "Trending")
272
+ if i.tags is not None: self._set("True" if "not-for-all-audiences" in i.tags else "False", "NFAA")
273
+ if type in ["model", "dataset"]:
274
+ if i.gated is not None: self._set(i.gated if i.gated else "off", "Gated")
275
+ if i.downloads is not None: self._set(i.downloads, "DLs")
276
+ if i.downloads_all_time is not None: self._set(i.downloads_all_time, "AllDLs")
277
+ if type == "model":
278
+ if i.inference is not None: self._set(i.inference, "Status")
279
+ if i.library_name is not None: self._set(i.library_name, "Library")
280
+ if i.pipeline_tag is not None: self._set(i.pipeline_tag, "Pipeline")
281
+ if type == "space":
282
+ if i.runtime is not None:
283
+ self._set(i.runtime.hardware, "Hardware")
284
+ self._set(i.runtime.stage, "Stage")
285
+ self._next()
286
+
287
+ def search(self, repo_types: list, sort: str, sort_method: str, filter_str: str, search_str: str, author: str, tags: str, infer: str, gated: str, appr: list[str],
288
+ size_categories: list, limit: int, hardware: list, stage: list, fetch_detail: list, show_labels: list):
289
+ try:
290
+ self.reset()
291
+ self.show_labels = show_labels.copy()
292
+ api = HfApi()
293
+ kwargs = {}
294
+ mkwargs = {}
295
+ dkwargs = {}
296
+ skwargs = {}
297
+ if filter_str: kwargs["filter"] = str_to_list(filter_str)
298
+ if search_str: kwargs["search"] = search_str
299
+ if author: kwargs["author"] = author
300
+ if tags and is_valid_arg(tags):
301
+ mkwargs["tags"] = str_to_list(tags)
302
+ dkwargs["tags"] = str_to_list(tags)
303
+ if limit > 0: kwargs["limit"] = limit
304
+ if sort_method == "descending order": kwargs["direction"] = -1
305
+ if gated == "gated":
306
+ mkwargs["gated"] = True
307
+ dkwargs["gated"] = True
308
+ elif gated == "non-gated":
309
+ mkwargs["gated"] = False
310
+ dkwargs["gated"] = False
311
+ mkwargs["sort"] = sort
312
+ if len(size_categories) > 0: dkwargs["size_categories"] = size_categories
313
+ if infer != "all": mkwargs["inference"] = infer
314
+ if "model" in repo_types:
315
+ models = api.list_models(full=True, cardData=True, **kwargs, **mkwargs)
316
+ for model in models:
317
+ if model.gated is not None and model.gated and model.gated not in appr: continue
318
+ self.add_item(model)
319
+ if "dataset" in repo_types:
320
+ datasets = api.list_datasets(full=True, **kwargs, **dkwargs)
321
+ for dataset in datasets:
322
+ if dataset.gated is not None and dataset.gated and dataset.gated not in appr: continue
323
+ self.add_item(dataset)
324
+ if "space" in repo_types:
325
+ if "Space Runtime" in fetch_detail:
326
+ spaces = api.list_spaces(expand=["cardData", "datasets", "disabled", "lastModified", "createdAt",
327
+ "likes", "models", "private", "runtime", "sdk", "sha", "tags", "trendingScore"], **kwargs, **skwargs)
328
+ else: spaces = api.list_spaces(full=True, **kwargs, **skwargs)
329
+ for space in spaces:
330
+ if space.gated is not None and space.gated and space.gated not in appr: continue
331
+ if space.runtime is not None:
332
+ if len(hardware) > 0 and space.runtime.stage == "RUNNING" and space.runtime.hardware not in hardware: continue
333
+ if len(stage) > 0 and space.runtime.stage not in stage: continue
334
+ self.add_item(space)
335
+ if sort == "downloads" and ("space" not in repo_types): self.sort("DLs")
336
+ elif sort == "downloads_all_time" and ("space" not in repo_types): self.sort("AllDLs")
337
+ elif sort == "likes": self.sort("Likes")
338
+ elif sort == "trending_score": self.sort("Trending")
339
+ else: self.sort("LastMod.")
340
+ except Exception as e:
341
+ raise Exception(f"Search error: {e}") from e
342
+
343
+ def search_ref_repos(self, repo_id: str, repo_types: list, sort: str, show_labels: list, limit=10):
344
+ try:
345
+ self.reset()
346
+ self.show_labels = show_labels.copy()
347
+ api = HfApi()
348
+ repos, types, counts = get_ref_repos(repo_id)
349
+ i = 0
350
+ for r, t in zip(repos, types):
351
+ if i + 1 > limit: break
352
+ i += 1
353
+ if t not in repo_types: continue
354
+ info = api.repo_info(repo_id=r, repo_type=t)
355
+ if info: self.add_item(info)
356
+ if sort == "downloads" and ("space" not in repo_types): self.sort("DLs")
357
+ elif sort == "downloads_all_time" and ("space" not in repo_types): self.sort("AllDLs")
358
+ elif sort == "likes": self.sort("Likes")
359
+ elif sort == "trending_score": self.sort("Trending")
360
+ else: self.sort("LastMod.")
361
+ except Exception as e:
362
+ raise Exception(f"Search error: {e}") from e
363
+
364
+ def get(self):
365
+ labels, label_types = self.labels.get()
366
+ self._do_filter()
367
+ dflist = [[item.get(l, self.labels.get_null_value(t)) for l, t in zip(labels, label_types)] for item, is_hide in zip(self.item_list, self.item_hide_flags) if not is_hide]
368
+ df = self._to_pandas(dflist, labels)
369
+ show_label_types = [t for l, t in zip(labels, label_types) if l not in self.hide_labels and l in self.show_labels]
370
+ show_labels = [l for l in labels if l not in self.hide_labels and l in self.show_labels]
371
+ return df, show_labels, show_label_types
372
+
373
+ def _to_pandas(self, dflist: list, labels: list):
374
+ # https://pandas.pydata.org/docs/reference/api/pandas.io.formats.style.Styler.apply.html
375
+ # https://stackoverflow.com/questions/41654949/pandas-style-function-to-highlight-specific-columns
376
+ # https://stackoverflow.com/questions/69832206/pandas-styling-with-conditional-rules
377
+ # https://stackoverflow.com/questions/41203959/conditionally-format-python-pandas-cell
378
+ # https://stackoverflow.com/questions/51187868/how-do-i-remove-and-re-sort-reindex-columns-after-applying-style-in-python-pan
379
+ # https://stackoverflow.com/questions/36921951/truth-value-of-a-series-is-ambiguous-use-a-empty-a-bool-a-item-a-any-o
380
+ def rank_df(sdf: pd.DataFrame, df: pd.DataFrame, col: str):
381
+ ranks = [(0.5, "gold"), (0.75, "orange"), (0.9, "orangered")]
382
+ for t, color in ranks:
383
+ sdf.loc[df[col] >= df[col].quantile(q=t), [col]] = f'color: {color}'
384
+ return sdf
385
+
386
+ def highlight_df(x: pd.DataFrame, df: pd.DataFrame):
387
+ sdf = pd.DataFrame("", index=x.copy().index, columns=x.copy().columns)
388
+ columns = df.columns
389
+ if "Trending" in columns: sdf = rank_df(sdf, df, "Trending")
390
+ if "Likes" in columns: sdf = rank_df(sdf, df, "Likes")
391
+ if "AllDLs" in columns: sdf = rank_df(sdf, df, "AllDLs")
392
+ if "DLs" in columns: sdf = rank_df(sdf, df, "DLs")
393
+ if "Status" in columns:
394
+ sdf.loc[df["Status"] == "warm", ["Type"]] = 'color: orange'
395
+ sdf.loc[df["Status"] == "cold", ["Type"]] = 'color: dodgerblue'
396
+ if "Gated" in columns:
397
+ sdf.loc[df["Gated"] == "auto", ["Gated"]] = 'color: dodgerblue'
398
+ sdf.loc[df["Gated"] == "manual", ["Gated"]] = 'color: crimson'
399
+ if "Stage" in columns and "Hardware" in columns:
400
+ sdf.loc[(df["Stage"] == "RUNNING") & (df["Hardware"] != "zero-a10g") & (df["Hardware"] != "cpu-basic") & (df["Hardware"] != "None") & (df["Hardware"]), ["Hardware", "Type"]] = 'color: lime'
401
+ sdf.loc[(df["Stage"] == "RUNNING") & (df["Hardware"] == "zero-a10g"), ["Hardware", "Type"]] = 'color: green'
402
+ sdf.loc[(df["Type"] == "space") & (df["Stage"] != "RUNNING")] = 'opacity: 0.5'
403
+ sdf.loc[(df["Type"] == "space") & (df["Stage"] != "RUNNING"), ["Type"]] = 'color: crimson'
404
+ sdf.loc[df["Stage"] == "RUNNING", ["Stage"]] = 'color: lime'
405
+ if "NFAA" in columns: sdf.loc[df["NFAA"] == "True", ["Type"]] = 'background-color: hotpink'
406
+ show_columns = x.copy().columns
407
+ style_columns = sdf.columns
408
+ drop_columns = [c for c in style_columns if c not in show_columns]
409
+ sdf = sdf.drop(drop_columns, axis=1)
410
+ return sdf
411
+
412
+ def id_to_md(df: pd.DataFrame):
413
+ if df["Type"] == "dataset": return f'[{df["ID"]}](https://hf.co/datasets/{df["ID"]})'
414
+ elif df["Type"] == "space": return f'[{df["ID"]}](https://hf.co/spaces/{df["ID"]})'
415
+ else: return f'[{df["ID"]}](https://hf.co/{df["ID"]})'
416
+
417
+ def format_md_df(df: pd.DataFrame):
418
+ df["ID"] = df.apply(id_to_md, axis=1)
419
+ return df
420
+
421
+ hide_labels = [l for l in labels if l in self.hide_labels or l not in self.show_labels]
422
+ df = format_md_df(pd.DataFrame(dflist, columns=labels))
423
+ ref_df = df.copy()
424
+ df = df.drop(hide_labels, axis=1).style.apply(highlight_df, axis=None, df=ref_df)
425
+ return df
426
+
427
+ def set_hide(self, hide_labels: list):
428
+ self.hide_labels = hide_labels.copy()
429
+
430
+ def set_filter(self, filter_item1: str, filter1: str):
431
+ if not filter_item1 and not filter1:
432
+ self.filter_items = None
433
+ self.filters = None
434
+ else:
435
+ self.filter_items = [filter_item1]
436
+ self.filters = [filter1]
437
+
438
+ def _do_filter(self):
439
+ if self.filters is None or self.filter_items is None:
440
+ self.item_hide_flags = [False] * len(self.item_list)
441
+ return
442
+ labels, label_types = self.labels.get()
443
+ types = dict(zip(labels, label_types))
444
+ flags = []
445
+ for item in self.item_list:
446
+ flag = False
447
+ for i, f in zip(self.filter_items, self.filters):
448
+ if i not in item.keys(): continue
449
+ t = types[i]
450
+ if item[i] == self.labels.get_null_value(t):
451
+ flag = True
452
+ break
453
+ if t in set(["str", "markdown"]):
454
+ if f in item[i]: flag = False
455
+ else:
456
+ flag = True
457
+ break
458
+ flags.append(flag)
459
+ self.item_hide_flags = flags
460
+
461
+ def sort(self, key="Likes"):
462
+ if len(self.item_list) == 0: raise Exception("No item found.")
463
+ if not key in self.labels.get()[0]: key = "Likes"
464
+ self.item_list, self.item_hide_flags, self.item_info_list = zip(*sorted(zip(self.item_list, self.item_hide_flags, self.item_info_list), key=lambda x: x[0][key], reverse=True))
465
+
466
+ def get_gr_df(self):
467
+ df, labels, label_types = self.get()
468
+ widths = self.labels.get_widths()
469
+ return gr.update(type="pandas", value=df, headers=labels, datatype=label_types, column_widths=widths, wrap=True)
470
+
471
+ def get_gr_hide_labels(self):
472
+ return gr.update(choices=self.labels.get()[0], value=[], visible=True)
473
+
474
+ def get_gr_filter_item(self, filter_item: str=""):
475
+ labels, label_types = self.labels.get()
476
+ choices = [s for s, t in zip(labels, label_types) if t in set(["str", "markdown"])]
477
+ if len(choices) == 0: choices = [""]
478
+ return gr.update(choices=choices, value=filter_item if filter_item else choices[0], visible=True)
479
+
480
+ def get_gr_filter(self, filter_item: str=""):
481
+ labels = self.labels.get()[0]
482
+ if not filter_item or filter_item not in set(labels): return gr.update(choices=[""], value="", visible=True)
483
+ d = {}
484
+ for item in self.item_list:
485
+ if filter_item not in item.keys(): continue
486
+ v = item[filter_item]
487
+ if v in d.keys(): d[v] += 1
488
+ else: d[v] = 1
489
+ return gr.update(choices=[""] + [t[0] for t in sorted(d.items(), key=lambda x : x[1])][:100], value="", visible=True)
490
+
491
+ def search(repo_types: list, sort: str, sort_method: str, filter_str: str, search_str: str, author: str, tags: str, infer: str,
492
+ gated: str, appr: list[str], size_categories: list, limit: int, hardware: list, stage: list, fetch_detail: list, show_labels: list, r: HFSearchResult):
493
+ try:
494
+ r.search(repo_types, sort, sort_method, filter_str, search_str, author, tags, infer, gated, appr, size_categories,
495
+ limit, hardware, stage, fetch_detail, show_labels)
496
+ return r.get_gr_df(), r.get_gr_hide_labels(), r
497
+ except Exception as e:
498
+ raise gr.Error(e)
499
+
500
+ def search_ref_repos(repo_id: str, repo_types: list, sort: str, show_labels: list, limit, r: HFSearchResult):
501
+ try:
502
+ if not repo_id: raise gr.Error("Input Repo ID")
503
+ r.search_ref_repos(repo_id, repo_types, sort, show_labels, limit)
504
+ return r.get_gr_df(), r.get_gr_hide_labels(), r
505
+ except Exception as e:
506
+ raise gr.Error(e)
507
+
508
+ def update_df(hide_labels: list, filter_item1: str, filter1: str, r: HFSearchResult):
509
+ r.set_hide(hide_labels)
510
+ r.set_filter(filter_item1, filter1)
511
+ return r.get_gr_df(), r
512
+
513
+ def update_filter(filter_item1: str, r: HFSearchResult):
514
+ return r.get_gr_filter_item(filter_item1), r.get_gr_filter(filter_item1), gr.update(visible=True), r
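
For reference, get_ref_repos() above builds its recommendation scores from two Hub signals: repos liked by the seed repo's likers and repos that share an upvoted collection with the seed, with collection co-occurrence weighted more heavily. A condensed sketch of that scoring, assuming anonymous Hub access; the seed repo and the exact weights (2 and 5) are illustrative, the committed code derives its counts slightly differently:

from huggingface_hub import HfApi

api = HfApi()
seed, seed_type = "NoCrypt/miku", "space"  # example seed (the theme Space referenced in app.py)
scores: dict[str, int] = {}

# Signal 1: what the first few likers of the seed repo also liked.
likers = [u.username for u in api.list_repo_likers(repo_id=seed, repo_type=seed_type)][:10]
for user in likers:
    liked = api.list_liked_repos(user=user)
    for rid in liked.models + liked.datasets + liked.spaces:
        scores[rid] = scores.get(rid, 0) + 2

# Signal 2: repos appearing in upvoted collections that contain the seed.
for col in api.list_collections(item=f"spaces/{seed}", sort="upvotes", limit=10):
    for it in api.get_collection(col.slug).items:
        if it.item_type != "paper":
            scores[it.item_id] = scores.get(it.item_id, 0) + 5

scores.pop(seed, None)
print(sorted(scores.items(), key=lambda kv: kv[1], reverse=True)[:20])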
pre-requirements.txt ADDED
@@ -0,0 +1 @@
+ pip>=24.1
requirements.txt ADDED
@@ -0,0 +1,2 @@
+ huggingface_hub
+ gradio_huggingfacehub_search
subtags.json ADDED
The diff for this file is too large to render. See raw diff
 
tags.json ADDED
The diff for this file is too large to render. See raw diff