import streamlit as st
import pandas as pd
import torch
import vec2text
from sklearn.decomposition import PCA
from transformers import (
    AutoModel,
    AutoTokenizer,
    PreTrainedModel,
    PreTrainedTokenizer,
)

from utils import file_cache


@st.cache_resource
def load_corrector():
    # Load the pretrained vec2text corrector for GTR-base embeddings;
    # @st.cache_resource keeps it to a single load per session.
    return vec2text.load_pretrained_corrector("gtr-base")


@st.cache_data
def load_data():
    # Train split of the reddit-syac-urls dataset, fetched from the
    # Hugging Face Hub.
    return pd.read_csv(
        "https://huggingface.co./datasets/marksverdhei/reddit-syac-urls/resolve/main/train.csv"
    )


@st.cache_resource
def vector_compressor_from_config():
    # Despite the name, the reducer is currently hardcoded: a 2-component
    # PCA for projecting embeddings down to plot coordinates.
    return PCA(n_components=2)


@st.cache_data
@file_cache(".cache/reducer_embeddings.pickle")
def reduce_embeddings(embeddings):
    # Fit the reducer and return both the 2D projection and the fitted
    # reducer itself, so later embeddings can be projected consistently.
    reducer = vector_compressor_from_config()
    return reducer.fit_transform(embeddings), reducer


@st.cache_resource
def load_model_and_tokenizer(device="cpu"):
    # Only the T5 encoder stack is needed to compute GTR embeddings.
    encoder = AutoModel.from_pretrained("sentence-transformers/gtr-t5-base").encoder.to(device)
    tokenizer = AutoTokenizer.from_pretrained("sentence-transformers/gtr-t5-base")
    return encoder, tokenizer


def get_gtr_embeddings(
    text_list: list[str],
    encoder: PreTrainedModel,
    tokenizer: PreTrainedTokenizer,
    device: str,
) -> torch.Tensor:
    inputs = tokenizer(
        text_list,
        return_tensors="pt",
        max_length=128,
        truncation=True,
        padding="max_length",
    ).to(device)

    with torch.no_grad():
        model_output = encoder(
            input_ids=inputs["input_ids"],
            attention_mask=inputs["attention_mask"],
        )
        hidden_state = model_output.last_hidden_state
        # Mean-pool the token hidden states over the attention mask so each
        # input text yields a single fixed-size GTR embedding.
        embeddings = vec2text.models.model_utils.mean_pool(
            hidden_state, inputs["attention_mask"]
        )

    return embeddings
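

# A minimal end-to-end sketch of how the helpers above fit together. The
# "title" column name is an assumption about the CSV layout, and num_steps
# is an arbitrary choice trading speed for reconstruction quality; running
# this as a plain script (outside `streamlit run`) only emits cache warnings.
if __name__ == "__main__":
    device = "cuda" if torch.cuda.is_available() else "cpu"
    df = load_data()
    encoder, tokenizer = load_model_and_tokenizer(device)
    embeddings = get_gtr_embeddings(
        df["title"].head(8).tolist(), encoder, tokenizer, device
    )
    # Project the embeddings to 2D; the fitted reducer can project
    # further embeddings via reducer.transform(...).
    points, reducer = reduce_embeddings(embeddings.cpu().numpy())
    print("2D projection shape:", points.shape)

    # Invert the first embedding back into text with vec2text.
    corrector = load_corrector()
    recovered = vec2text.invert_embeddings(
        embeddings=embeddings[:1].to(device),
        corrector=corrector,
        num_steps=20,
    )
    print(recovered)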