File size: 3,844 Bytes
9ae1b66
 
 
 
6621d73
9ae1b66
 
 
 
 
 
 
3d12d3a
 
6f507b4
6621d73
356f92c
 
9ae1b66
 
 
 
ef9cbc8
9ae1b66
9930cd7
9ae1b66
9930cd7
9ae1b66
9930cd7
9ae1b66
9930cd7
9ae1b66
 
 
 
3772eaf
356f92c
3772eaf
9ae1b66
 
 
 
6621d73
 
 
9ae1b66
 
 
ef9cbc8
9ae1b66
 
 
 
 
 
 
 
 
ef9cbc8
9ae1b66
 
 
3772eaf
9ae1b66
 
ef9cbc8
 
9ae1b66
 
6621d73
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3772eaf
a6deb48
9ae1b66
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
import os

import numpy as np
import pandas as pd
import requests
from datasets import Dataset, DownloadMode, load_dataset
from gradio_client import Client

from src.my_logger import setup_logger

SUBREDDIT = os.environ["SUBREDDIT"]
USERNAME = os.environ["USERNAME"]
# Source-of-truth dataset: its rows/content are merged against PROCESSED_DATASET.
OG_DATASET = f"{USERNAME}/dataset-creator-reddit-{SUBREDDIT}"
# Dataset holding the content plus its embeddings; refreshed by this module.
PROCESSED_DATASET = os.environ["PROCESSED_DATASET"]
# Gradio space queried through gradio_client.Client to compute embeddings.
embeddings_space = "derek-thomas/nomic-embeddings"
# JSON file listing post IDs that must be removed from the original dataset.
FILTER_IDS_URL = "https://huggingface.co/spaces/reddit-tools-HF/dataset-creator-reddit-bestofredditorupdates/raw/main/filter_ids.json"
# Optional: None when unset; passed to the embeddings Client.
HF_TOKEN = os.environ.get("HF_TOKEN")


logger = setup_logger(__name__)


def load_datasets():
    """Force a fresh download of the processed and the original dataset.

    Any locally cached copy is ignored (``DownloadMode.FORCE_REDOWNLOAD``), so the
    latest revision of each repo is always used.

    :return: Tuple ``(dataset, original_dataset)`` loaded from
        ``PROCESSED_DATASET`` and ``OG_DATASET`` respectively.
    """

    def _fetch(repo_id):
        # Log before and after so a hanging download is easy to spot.
        logger.info(f"Trying to download {repo_id}")
        loaded = load_dataset(repo_id, download_mode=DownloadMode.FORCE_REDOWNLOAD)
        logger.info(f"Loaded {repo_id}")
        return loaded

    processed = _fetch(PROCESSED_DATASET)
    original = _fetch(OG_DATASET)
    return processed, original


def merge_and_update_datasets(dataset, original_dataset):
    """Align the processed dataset with the original one and refresh embeddings.

    Rows of the original dataset (minus the IDs listed at ``FILTER_IDS_URL``) are
    left-merged with the processed dataset on ``id``. A row keeps its existing
    embedding only when its ``content`` is unchanged. Rows that are new, whose
    content changed, or that still have no embedding are embedded through the
    embeddings space.

    :param dataset: Processed dataset dict; its ``'train'`` split must provide
        ``id``, ``content`` and ``embedding`` columns. Its ``'train'`` split is
        replaced in place.
    :param original_dataset: Original dataset dict; its ``'train'`` split is the
        source of truth for which rows exist and what their content is.
    :return: Tuple ``(dataset, updated_row_count)``. The count is the number of
        rows that were (re-)embedded during this call.
    """
    odf = original_dataset['train'].to_pandas()
    df = dataset['train'].to_pandas()

    # Filter ODF in case we missed any
    odf = remove_filtered_rows(odf, FILTER_IDS_URL)

    # Left merge keeps every original row; rows absent from df get NaN content/embedding.
    merged_df = pd.merge(odf, df[['id', 'content', 'embedding']], on='id', how='left', suffixes=('_odf', ''))

    # Invalidate embeddings whose text changed. New rows compare NaN != str -> True,
    # so they are invalidated too.
    content_changed = merged_df['content_odf'] != merged_df['content']
    merged_df['embedding'] = np.where(content_changed, None, merged_df['embedding'])

    # Keep odf's content; drop df's copy plus odf bookkeeping columns when present.
    merged_df = merged_df.drop(columns=['content'])
    merged_df = merged_df.drop(columns=['new', 'updated'], errors='ignore')
    merged_df = merged_df.rename(columns={'content_odf': 'content'})

    # Count exactly the rows the loop below embeds (changed, new, or never embedded),
    # so the log and the returned count match the work actually done.
    needs_embedding = merged_df['embedding'].isnull()
    updated_row_count = int(needs_embedding.sum())

    logger.info(f"Updating {updated_row_count} rows...")
    if updated_row_count:
        # Only connect to the embeddings space when there is something to embed.
        client = Client(embeddings_space, hf_token=HF_TOKEN)
        for index, content in merged_df.loc[needs_embedding, 'content'].items():
            merged_df.at[index, 'embedding'] = update_embeddings(content=content, client=client)

    dataset['train'] = Dataset.from_pandas(merged_df)
    logger.info(f"Updated {updated_row_count} rows")
    return dataset, updated_row_count


def remove_filtered_rows(df: pd.DataFrame, url: str, timeout: float = 30.0) -> pd.DataFrame:
    """
    Removes rows from the DataFrame where the 'id' is present in the JSON file at the given URL.

    IDs are compared as strings on both sides, so the JSON may list them as
    strings or integers.

    :param df: Input DataFrame to be filtered; must have an 'id' column.
    :param url: URL to the JSON file containing the filter IDs (expected to be a JSON array).
    :param timeout: Seconds to wait for the HTTP response before giving up.
    :return: DataFrame with rows containing IDs present in the JSON file removed.
    :raises requests.HTTPError: If the server responds with an error status.
    """

    # Without a timeout, requests can block indefinitely on a stalled connection.
    response = requests.get(url, timeout=timeout)
    # Fail loudly on 4xx/5xx rather than trying to parse an error page as JSON.
    response.raise_for_status()
    filter_ids = response.json()

    logger.info(f"Loaded {len(filter_ids)} IDs from {url}")

    # df ids are stringified below, so normalize the filter IDs the same way.
    ids_to_drop = {str(filter_id) for filter_id in filter_ids}
    filtered_df = df[~df['id'].astype(str).isin(ids_to_drop)]

    logger.info(f"Filtered {len(df) - len(filtered_df)} rows from the DataFrame")

    return filtered_df


def update_embeddings(content, client):
    """Embed ``content`` as a search document using the given client.

    The text is prefixed with ``"search_document: "`` before being sent to the
    client's ``/embed`` endpoint.

    :param content: Text to embed.
    :param client: Object exposing ``predict(text, api_name=...)``.
    :return: The client's result converted to a NumPy array.
    """
    prompt = "search_document: " + content
    raw_vector = client.predict(prompt, api_name="/embed")
    return np.array(raw_vector)