# Last change: "Add correct device" (commit 05dd656, author: marksverdhei)
import streamlit as st
import vec2text
import torch
from umap import UMAP
import plotly.express as px
import numpy as np
from streamlit_plotly_events import plotly_events
import utils
import pandas as pd
from scipy.spatial import distance
from resources import get_gtr_embeddings
from transformers import PreTrainedModel, PreTrainedTokenizer
# Display name of the dimensionality-reduction method; interpolated into the
# axis labels of the scatter plot built in plot().
dimensionality_reduction_model_name = "PCA"
def diffs(embeddings: np.ndarray, corrector, encoder: PreTrainedModel, tokenizer: PreTrainedTokenizer):
    """Render the "sentence arithmetic" page.

    Shows three editable sentences inside a form. When the form is submitted,
    the three sentences are embedded with ``get_gtr_embeddings``, combined as
    ``sentence1 - sentence2 + sentence3``, and the resulting vector is passed
    to ``vec2text.invert_embeddings``. The decoded text is displayed in a
    disabled "Sentence 4" field (empty until the first submission).

    Args:
        embeddings: Not used by this view; kept for interface compatibility.
        corrector: Passed through unchanged to ``vec2text.invert_embeddings``.
        encoder: Model handed to ``get_gtr_embeddings``; its ``device``
            attribute decides where the embedding tensors are placed.
        tokenizer: Tokenizer handed to ``get_gtr_embeddings``.

    Returns:
        None. The function only emits Streamlit widgets.
    """
    st.title('"A man is to king, what woman is to queen"')
    st.markdown("A well known phenomenon in semantic vectors is the way we can do vector operations like addition and subtraction to find spacial relations in the vector space.")
    st.markdown(
        'In word embedding models, we have found that the relationship between words can be captured mathematically, '
        'such that "king" is to "man" as "queen" is to "woman," demonstrating that vector arithmetic can encode analogies and semantic relationships in high-dimensional space ([Mikolov et al., 2013](https://arxiv.org/abs/1301.3781)).'
    )
    st.markdown("This application lets you freely explore to which extent that property applies to embedding inversion models given the other factors of inaccuracy")

    generated_sentence = ""
    device = encoder.device

    with st.form(key="foo"):
        submit_button = st.form_submit_button("Synthesize")

        sent1 = st.text_input("Sentence 1", value="I am a king")
        st.latex("-")
        sent2 = st.text_input("Sentence 2", value="I am a man")
        st.latex("+")
        sent3 = st.text_input("Sentence 3", value="I am a woman")
        st.latex("=")

        if submit_button:
            # One batched call; the result is unpacked along its first axis
            # into one embedding per sentence.
            v1, v2, v3 = get_gtr_embeddings(
                [sent1, sent2, sent3], encoder, tokenizer, device=device
            ).to(device)
            # v4 is built from tensors already on `device`, so it needs no
            # further transfer before inversion.
            v4 = v1 - v2 + v3
            generated_sentence, = vec2text.invert_embeddings(
                embeddings=v4.unsqueeze(0),
                corrector=corrector,
                num_steps=20,
            )
            generated_sentence = generated_sentence.strip()

        # Read-only output field; the widget's return value is not needed.
        st.text_input("Sentence 4", value=generated_sentence, disabled=True)

    # st.html('<a href="https://www.flaticon.com/free-icons/array" title="array icons">Array icons created by Voysla - Flaticon</a>')
def plot(df: pd.DataFrame, embeddings: np.ndarray, vectors_2d, reducer, corrector, device):
    """Render the interactive 2-D scatter plot and its inversion side panel.

    Left column: a clickable Plotly scatter of ``vectors_2d`` and a form with
    X/Y number inputs (pre-filled from the clicked point, otherwise 0.0). On
    submit, the 2-D point is mapped back with ``reducer.inverse_transform``
    and the result is decoded with ``vec2text.invert_embeddings``.

    Right column: the title that ``utils.find_exact_match`` associates with
    the current coordinates, the synthesized text, and, when both an inferred
    embedding and a matched title exist, the Euclidean and cosine distance
    between the matched embedding and the inferred one.

    Args:
        df: Frame with a ``"title"`` column, used for hover text and for
            looking up the selected title by position.
        embeddings: Indexed with the match index to fetch the matched
            embedding.
        vectors_2d: 2-D projection; columns 0 and 1 are plotted as x and y.
        reducer: Object exposing ``inverse_transform``.
        corrector: Passed through unchanged to ``vec2text.invert_embeddings``.
        device: Target device for the tensor handed to the inverter.

    Returns:
        None. The function only emits Streamlit widgets.
    """
    fig = px.scatter(
        x=vectors_2d[:, 0],
        y=vectors_2d[:, 1],
        opacity=0.6,
        hover_data={"Title": df["title"]},
        labels={
            'x': f'{dimensionality_reduction_model_name} Component 1',
            'y': f'{dimensionality_reduction_model_name} Component 2',
        },
        # Use the same method name as the axis labels so the two cannot disagree.
        title=f"{dimensionality_reduction_model_name} Scatter Plot of Reddit Titles",
        color_discrete_sequence=["#ff504c"],  # single red marker colour
    )

    # No explicit template, and fully transparent plot/paper backgrounds.
    fig.update_layout(
        template=None,
        plot_bgcolor="rgba(0, 0, 0, 0)",
        paper_bgcolor="rgba(0, 0, 0, 0)",
    )

    x, y = 0.0, 0.0
    inferred_embedding = None
    inversion_output_text = None

    # Wider column for the plot and form, narrower one for the result card.
    col1, col2 = st.columns([0.6, 0.4])

    with col1:
        selected_points = plotly_events(fig, click_event=True, hover_event=False)

        with st.form(key="form1_main"):
            if selected_points:
                clicked_point = selected_points[0]
                # Coerce to float so the "%.10f" inputs below always receive
                # a float, even if the event carries an integer coordinate.
                x = float(clicked_point['x'])
                y = float(clicked_point['y'])

            x = st.number_input("X Coordinate", value=x, format="%.10f")
            y = st.number_input("Y Coordinate", value=y, format="%.10f")
            vec = np.array([x, y]).astype("float32")

            submit_button = st.form_submit_button("Synthesize")

            if submit_button:
                # A single point shaped (1, 2) is sent back through the reducer.
                inferred_embedding = reducer.inverse_transform(np.array([[x, y]]))
                inferred_embedding = inferred_embedding.astype("float32")

                inversion_output_text, = vec2text.invert_embeddings(
                    embeddings=torch.tensor(inferred_embedding).to(device),
                    corrector=corrector,
                    num_steps=20,
                )
            else:
                st.text("Click on a point in the scatterplot to see its coordinates.")

    with col2:
        closest_sentence_index = utils.find_exact_match(vectors_2d, vec, decimals=3)
        # A negative index is treated as "no match" throughout this panel.
        has_match = closest_sentence_index > -1

        selected_sentence = df["title"].iloc[closest_sentence_index] if has_match else None
        selected_sentence_embedding = embeddings[closest_sentence_index] if has_match else None

        st.markdown(
            f"### Selected text:\n```console\n{selected_sentence}\n```"
        )
        st.markdown(
            f"### Synthesized text:\n```console\n{inversion_output_text}\n```"
        )

        if inferred_embedding is not None and has_match:
            # Squeeze both operands so scipy receives two 1-D vectors.
            couple = selected_sentence_embedding.squeeze(), inferred_embedding.squeeze()

            st.markdown("### Inferred embedding distance:")
            st.number_input("Euclidean", value=distance.euclidean(*couple), disabled=True)
            st.number_input("Cosine", value=distance.cosine(*couple), disabled=True)