# Exported from a Hugging Face Space by marksverdhei — commit 6b30d5d ("Add explanation").
import streamlit as st
import pandas as pd
import torch
import vec2text
from transformers import AutoModel, AutoTokenizer
from sklearn.decomposition import PCA
from utils import file_cache
from transformers import PreTrainedModel, PreTrainedTokenizer
# Caching the vec2text corrector
@st.cache_resource
def load_corrector():
    """Load (once) and cache the pretrained vec2text corrector for the GTR base encoder."""
    corrector = vec2text.load_pretrained_corrector("gtr-base")
    return corrector
# Caching the dataframe since loading from an external source can be time-consuming
@st.cache_data
def load_data():
    """Download the reddit-syac-urls training split and cache the resulting DataFrame."""
    dataset_url = "https://huggingface.co./datasets/marksverdhei/reddit-syac-urls/resolve/main/train.csv"
    return pd.read_csv(dataset_url)
@st.cache_resource
def vector_compressor_from_config(n_components: int = 2):
    """Build and cache the dimensionality reducer used to project embeddings for plotting.

    Parameters
    ----------
    n_components : int, optional
        Target dimensionality of the reduced space. Defaults to 2 (2-D scatter plots),
        matching the original hard-coded behavior.

    Returns
    -------
    A scikit-learn style reducer exposing ``fit_transform``.
    """
    # UMAP was considered as an alternative reducer:
    # return UMAP(n_components=n_components)
    return PCA(n_components=n_components)
@st.cache_data
@file_cache(".cache/reducer_embeddings.pickle")
def reduce_embeddings(embeddings):
    """Fit the configured reducer on *embeddings* and return (reduced_points, fitted_reducer).

    Results are cached both in Streamlit's data cache and on disk via ``file_cache``.
    """
    compressor = vector_compressor_from_config()
    reduced = compressor.fit_transform(embeddings)
    return reduced, compressor
# Caching the model and tokenizer to avoid reloading
@st.cache_resource
def load_model_and_tokenizer(device="cpu"):
    """Load the GTR-T5 base model (encoder half only) and its tokenizer, cached.

    Parameters
    ----------
    device : str, optional
        Torch device the encoder is moved to (default ``"cpu"``).

    Returns
    -------
    (encoder, tokenizer) tuple.
    """
    checkpoint = "sentence-transformers/gtr-t5-base"
    tokenizer = AutoTokenizer.from_pretrained(checkpoint)
    # Only the encoder stack is needed to produce embeddings.
    encoder = AutoModel.from_pretrained(checkpoint).encoder.to(device)
    return encoder, tokenizer
def get_gtr_embeddings(text_list: list[str],
                       encoder: PreTrainedModel,
                       tokenizer: PreTrainedTokenizer,
                       device: str,
                       ) -> torch.Tensor:
    """Embed each string in *text_list* with the GTR encoder and mean-pool to sentence vectors.

    Texts are tokenized to a fixed length of 128 (padded/truncated); the encoder's
    final hidden states are mean-pooled over non-padding tokens.
    """
    batch = tokenizer(
        text_list,
        return_tensors="pt",
        max_length=128,
        truncation=True,
        padding="max_length",
    ).to(device)

    with torch.no_grad():
        outputs = encoder(
            input_ids=batch['input_ids'],
            attention_mask=batch['attention_mask'],
        )
        final_hidden = outputs.last_hidden_state
        # Attention-mask-aware mean pooling, as in the vec2text recipe.
        pooled = vec2text.models.model_utils.mean_pool(final_hidden, batch['attention_mask'])

    return pooled