Spaces:
Runtime error
Runtime error
Commit
•
3772eaf
1
Parent(s):
7d5ff0e
Move client instantiation
Browse files- src/utilities.py +5 -3
src/utilities.py
CHANGED
@@ -12,7 +12,6 @@ USERNAME = os.environ["USERNAME"]
|
|
12 |
OG_DATASET = f"{USERNAME}/dataset-creator-reddit-{SUBREDDIT}"
|
13 |
PROCESSED_DATASET = os.environ['PROCESSED_DATASET']
|
14 |
|
15 |
-
client = Client("derek-thomas/nomic-embeddings")
|
16 |
logger = setup_logger(__name__)
|
17 |
|
18 |
|
@@ -29,6 +28,9 @@ async def load_datasets():
|
|
29 |
|
30 |
|
31 |
def merge_and_update_datasets(dataset, original_dataset):
|
|
|
|
|
|
|
32 |
# Merge and figure out which rows need to be updated with embeddings
|
33 |
odf = original_dataset['train'].to_pandas()
|
34 |
df = dataset['train'].to_pandas()
|
@@ -50,13 +52,13 @@ def merge_and_update_datasets(dataset, original_dataset):
|
|
50 |
# Iterate over the DataFrame rows where 'embedding' is None
|
51 |
for index, row in merged_df[merged_df['embedding'].isnull()].iterrows():
|
52 |
# Update 'embedding' for the current row using our function
|
53 |
-
merged_df.at[index, 'embedding'] = update_embeddings(row['content'])
|
54 |
|
55 |
dataset['train'] = Dataset.from_pandas(merged_df)
|
56 |
logger.info(f"Updated {updated_rows} rows")
|
57 |
return dataset
|
58 |
|
59 |
|
60 |
-
def update_embeddings(content):
|
61 |
embedding = client.predict(content, api_name="/embed")
|
62 |
return np.array(embedding)
|
|
|
12 |
OG_DATASET = f"{USERNAME}/dataset-creator-reddit-{SUBREDDIT}"
|
13 |
PROCESSED_DATASET = os.environ['PROCESSED_DATASET']
|
14 |
|
|
|
15 |
logger = setup_logger(__name__)
|
16 |
|
17 |
|
|
|
28 |
|
29 |
|
30 |
def merge_and_update_datasets(dataset, original_dataset):
|
31 |
+
# Get client
|
32 |
+
client = Client("derek-thomas/nomic-embeddings")
|
33 |
+
|
34 |
# Merge and figure out which rows need to be updated with embeddings
|
35 |
odf = original_dataset['train'].to_pandas()
|
36 |
df = dataset['train'].to_pandas()
|
|
|
52 |
# Iterate over the DataFrame rows where 'embedding' is None
|
53 |
for index, row in merged_df[merged_df['embedding'].isnull()].iterrows():
|
54 |
# Update 'embedding' for the current row using our function
|
55 |
+
merged_df.at[index, 'embedding'] = update_embeddings(content=row['content'], client=client)
|
56 |
|
57 |
dataset['train'] = Dataset.from_pandas(merged_df)
|
58 |
logger.info(f"Updated {updated_rows} rows")
|
59 |
return dataset
|
60 |
|
61 |
|
62 |
+
def update_embeddings(content, client):
|
63 |
embedding = client.predict(content, api_name="/embed")
|
64 |
return np.array(embedding)
|