marksverdhei
commited on
Commit
•
05dd656
1
Parent(s):
c4e154c
Add correct device
Browse files
app.py
CHANGED
@@ -45,6 +45,7 @@ with tab1:
|
|
45 |
vectors_2d=vectors_2d,
|
46 |
reducer=reducer,
|
47 |
corrector=corrector,
|
|
|
48 |
)
|
49 |
|
50 |
with tab2:
|
|
|
45 |
vectors_2d=vectors_2d,
|
46 |
reducer=reducer,
|
47 |
corrector=corrector,
|
48 |
+
device=device,
|
49 |
)
|
50 |
|
51 |
with tab2:
|
views.py
CHANGED
@@ -53,7 +53,7 @@ def diffs(embeddings: np.ndarray, corrector, encoder: PreTrainedModel, tokenizer
|
|
53 |
|
54 |
# st.html('<a href="https://www.flaticon.com/free-icons/array" title="array icons">Array icons created by Voysla - Flaticon</a>')
|
55 |
|
56 |
-
def plot(df: pd.DataFrame, embeddings: np.ndarray, vectors_2d, reducer, corrector):
|
57 |
|
58 |
|
59 |
# Add a scatter plot using Plotly
|
@@ -83,7 +83,7 @@ def plot(df: pd.DataFrame, embeddings: np.ndarray, vectors_2d, reducer, correcto
|
|
83 |
|
84 |
with col1:
|
85 |
# Main content stays here (scatterplot, form, etc.)
|
86 |
-
selected_points = plotly_events(fig, click_event=True, hover_event=False
|
87 |
)
|
88 |
with st.form(key="form1_main"):
|
89 |
if selected_points:
|
@@ -103,7 +103,7 @@ def plot(df: pd.DataFrame, embeddings: np.ndarray, vectors_2d, reducer, correcto
|
|
103 |
inferred_embedding = inferred_embedding.astype("float32")
|
104 |
|
105 |
inversion_output_text, = vec2text.invert_embeddings(
|
106 |
-
embeddings=torch.tensor(inferred_embedding).
|
107 |
corrector=corrector,
|
108 |
num_steps=20,
|
109 |
)
|
|
|
53 |
|
54 |
# st.html('<a href="https://www.flaticon.com/free-icons/array" title="array icons">Array icons created by Voysla - Flaticon</a>')
|
55 |
|
56 |
+
def plot(df: pd.DataFrame, embeddings: np.ndarray, vectors_2d, reducer, corrector, device):
|
57 |
|
58 |
|
59 |
# Add a scatter plot using Plotly
|
|
|
83 |
|
84 |
with col1:
|
85 |
# Main content stays here (scatterplot, form, etc.)
|
86 |
+
selected_points = plotly_events(fig, click_event=True, hover_event=False,# override_height="600", override_width="600"
|
87 |
)
|
88 |
with st.form(key="form1_main"):
|
89 |
if selected_points:
|
|
|
103 |
inferred_embedding = inferred_embedding.astype("float32")
|
104 |
|
105 |
inversion_output_text, = vec2text.invert_embeddings(
|
106 |
+
embeddings=torch.tensor(inferred_embedding).to(device),
|
107 |
corrector=corrector,
|
108 |
num_steps=20,
|
109 |
)
|