fix cast index and no-labels errors
Browse files
.gitignore
CHANGED
@@ -129,6 +129,7 @@ venv/
|
|
129 |
ENV/
|
130 |
env.bak/
|
131 |
venv.bak/
|
|
|
132 |
|
133 |
# Spyder project settings
|
134 |
.spyderproject
|
|
|
129 |
ENV/
|
130 |
env.bak/
|
131 |
venv.bak/
|
132 |
+
.python-version
|
133 |
|
134 |
# Spyder project settings
|
135 |
.spyderproject
|
src/synthetic_dataset_generator/apps/textcat.py
CHANGED
@@ -64,7 +64,7 @@ def generate_system_prompt(dataset_description, progress=gr.Progress()):
|
|
64 |
progress(1.0, desc="Prompt generated")
|
65 |
data = json.loads(result)
|
66 |
system_prompt = data["classification_task"]
|
67 |
-
labels = data["labels"]
|
68 |
return system_prompt, labels
|
69 |
|
70 |
|
@@ -177,14 +177,20 @@ def generate_dataset(
|
|
177 |
distiset_results.append(record)
|
178 |
|
179 |
dataframe = pd.DataFrame(distiset_results)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
180 |
if multi_label:
|
181 |
dataframe["labels"] = dataframe["labels"].apply(
|
182 |
lambda x: list(
|
183 |
set(
|
184 |
[
|
185 |
-
label.lower().strip()
|
186 |
for label in x
|
187 |
-
if label is not None and label.lower().strip() in labels
|
188 |
]
|
189 |
)
|
190 |
)
|
@@ -214,6 +220,7 @@ def push_dataset_to_hub(
|
|
214 |
pipeline_code: str = "",
|
215 |
progress=gr.Progress(),
|
216 |
):
|
|
|
217 |
progress(0.0, desc="Validating")
|
218 |
repo_id = validate_push_to_hub(org_name, repo_name)
|
219 |
progress(0.3, desc="Preprocessing")
|
@@ -230,7 +237,10 @@ def push_dataset_to_hub(
|
|
230 |
features = Features(
|
231 |
{"text": Value("string"), "label": ClassLabel(names=labels)}
|
232 |
)
|
233 |
-
dataset = Dataset.from_pandas(
|
|
|
|
|
|
|
234 |
dataset = combine_datasets(repo_id, dataset)
|
235 |
distiset = Distiset({"default": dataset})
|
236 |
progress(0.9, desc="Pushing dataset")
|
@@ -269,6 +279,7 @@ def push_dataset(
|
|
269 |
num_rows=num_rows,
|
270 |
temperature=temperature,
|
271 |
)
|
|
|
272 |
push_dataset_to_hub(
|
273 |
dataframe,
|
274 |
org_name,
|
@@ -365,7 +376,7 @@ def push_dataset(
|
|
365 |
and all(label in labels for label in sample["labels"])
|
366 |
)
|
367 |
)
|
368 |
-
else
|
369 |
),
|
370 |
)
|
371 |
for sample in hf_dataset
|
|
|
64 |
progress(1.0, desc="Prompt generated")
|
65 |
data = json.loads(result)
|
66 |
system_prompt = data["classification_task"]
|
67 |
+
labels = get_preprocess_labels(data["labels"])
|
68 |
return system_prompt, labels
|
69 |
|
70 |
|
|
|
177 |
distiset_results.append(record)
|
178 |
|
179 |
dataframe = pd.DataFrame(distiset_results)
|
180 |
+
if (
|
181 |
+
not labels
|
182 |
+
or len(set(label.lower().strip() for label in labels if label.strip())) < 2
|
183 |
+
):
|
184 |
+
raise gr.Error(
|
185 |
+
"Please provide at least 2 unique, non-empty labels to classify your text."
|
186 |
+
)
|
187 |
if multi_label:
|
188 |
dataframe["labels"] = dataframe["labels"].apply(
|
189 |
lambda x: list(
|
190 |
set(
|
191 |
[
|
192 |
+
label.lower().strip() if (label is not None and label.lower().strip() in labels) else random.choice(labels)
|
193 |
for label in x
|
|
|
194 |
]
|
195 |
)
|
196 |
)
|
|
|
220 |
pipeline_code: str = "",
|
221 |
progress=gr.Progress(),
|
222 |
):
|
223 |
+
gr.Info(message=f"Dataframe columns in push dataset to hub: {dataframe.columns}", duration=20)
|
224 |
progress(0.0, desc="Validating")
|
225 |
repo_id = validate_push_to_hub(org_name, repo_name)
|
226 |
progress(0.3, desc="Preprocessing")
|
|
|
237 |
features = Features(
|
238 |
{"text": Value("string"), "label": ClassLabel(names=labels)}
|
239 |
)
|
240 |
+
dataset = Dataset.from_pandas(
|
241 |
+
dataframe.reset_index(drop=True),
|
242 |
+
features=features,
|
243 |
+
)
|
244 |
dataset = combine_datasets(repo_id, dataset)
|
245 |
distiset = Distiset({"default": dataset})
|
246 |
progress(0.9, desc="Pushing dataset")
|
|
|
279 |
num_rows=num_rows,
|
280 |
temperature=temperature,
|
281 |
)
|
282 |
+
gr.Info(message=f"Dataframe columns: {dataframe.columns}", duration=20)
|
283 |
push_dataset_to_hub(
|
284 |
dataframe,
|
285 |
org_name,
|
|
|
376 |
and all(label in labels for label in sample["labels"])
|
377 |
)
|
378 |
)
|
379 |
+
else None
|
380 |
),
|
381 |
)
|
382 |
for sample in hf_dataset
|