marksverdhei commited on
Commit
05dd656
1 Parent(s): c4e154c

Add correct device

Browse files
Files changed (2) hide show
  1. app.py +1 -0
  2. views.py +3 -3
app.py CHANGED
@@ -45,6 +45,7 @@ with tab1:
45
  vectors_2d=vectors_2d,
46
  reducer=reducer,
47
  corrector=corrector,
 
48
  )
49
 
50
  with tab2:
 
45
  vectors_2d=vectors_2d,
46
  reducer=reducer,
47
  corrector=corrector,
48
+ device=device,
49
  )
50
 
51
  with tab2:
views.py CHANGED
@@ -53,7 +53,7 @@ def diffs(embeddings: np.ndarray, corrector, encoder: PreTrainedModel, tokenizer
53
 
54
  # st.html('<a href="https://www.flaticon.com/free-icons/array" title="array icons">Array icons created by Voysla - Flaticon</a>')
55
 
56
- def plot(df: pd.DataFrame, embeddings: np.ndarray, vectors_2d, reducer, corrector):
57
 
58
 
59
  # Add a scatter plot using Plotly
@@ -83,7 +83,7 @@ def plot(df: pd.DataFrame, embeddings: np.ndarray, vectors_2d, reducer, correcto
83
 
84
  with col1:
85
  # Main content stays here (scatterplot, form, etc.)
86
- selected_points = plotly_events(fig, click_event=True, hover_event=False, #override_height=600, override_width="100%"
87
  )
88
  with st.form(key="form1_main"):
89
  if selected_points:
@@ -103,7 +103,7 @@ def plot(df: pd.DataFrame, embeddings: np.ndarray, vectors_2d, reducer, correcto
103
  inferred_embedding = inferred_embedding.astype("float32")
104
 
105
  inversion_output_text, = vec2text.invert_embeddings(
106
- embeddings=torch.tensor(inferred_embedding).cuda(),
107
  corrector=corrector,
108
  num_steps=20,
109
  )
 
53
 
54
  # st.html('<a href="https://www.flaticon.com/free-icons/array" title="array icons">Array icons created by Voysla - Flaticon</a>')
55
 
56
+ def plot(df: pd.DataFrame, embeddings: np.ndarray, vectors_2d, reducer, corrector, device):
57
 
58
 
59
  # Add a scatter plot using Plotly
 
83
 
84
  with col1:
85
  # Main content stays here (scatterplot, form, etc.)
86
+ selected_points = plotly_events(fig, click_event=True, hover_event=False,# override_height="600", override_width="600"
87
  )
88
  with st.form(key="form1_main"):
89
  if selected_points:
 
103
  inferred_embedding = inferred_embedding.astype("float32")
104
 
105
  inversion_output_text, = vec2text.invert_embeddings(
106
+ embeddings=torch.tensor(inferred_embedding).to(device),
107
  corrector=corrector,
108
  num_steps=20,
109
  )