cheesyFishes commited on
Commit
a78f8a0
·
1 Parent(s): ffd45a6

add back CUDA

Browse files
Files changed (1) hide show
  1. app.py +8 -8
app.py CHANGED
@@ -14,15 +14,15 @@ example_indexes = {
14
  "Uber 10k 2021": "./uber_index",
15
  }
16
 
17
- # device = "cpu"
18
- # if torch.cuda.is_available():
19
- # device = "cuda"
20
- # elif torch.backends.mps.is_available():
21
- # device = "mps"
22
 
23
  image_embed_model = HuggingFaceEmbedding(
24
  model_name="llamaindex/vdr-2b-multi-v1",
25
- device="cpu",
26
  trust_remote_code=True,
27
  token=os.getenv("HUGGINGFACE_TOKEN"),
28
  model_kwargs={"torch_dtype": torch.float16},
@@ -31,10 +31,10 @@ image_embed_model = HuggingFaceEmbedding(
31
 
32
  text_embed_model = HuggingFaceEmbedding(
33
  model_name="BAAI/bge-small-en",
34
- device="cpu",
35
  trust_remote_code=True,
36
  token=os.getenv("HUGGINGFACE_TOKEN"),
37
- embed_batch_size=2,
38
  )
39
 
40
  def load_index(index_path: str) -> MultiModalVectorStoreIndex:
 
14
  "Uber 10k 2021": "./uber_index",
15
  }
16
 
17
+ device = "cpu"
18
+ if torch.cuda.is_available():
19
+ device = "cuda"
20
+ elif torch.backends.mps.is_available():
21
+ device = "mps"
22
 
23
  image_embed_model = HuggingFaceEmbedding(
24
  model_name="llamaindex/vdr-2b-multi-v1",
25
+ device=device,
26
  trust_remote_code=True,
27
  token=os.getenv("HUGGINGFACE_TOKEN"),
28
  model_kwargs={"torch_dtype": torch.float16},
 
31
 
32
  text_embed_model = HuggingFaceEmbedding(
33
  model_name="BAAI/bge-small-en",
34
+ device=device,
35
  trust_remote_code=True,
36
  token=os.getenv("HUGGINGFACE_TOKEN"),
37
+ embed_batch_size=1,
38
  )
39
 
40
  def load_index(index_path: str) -> MultiModalVectorStoreIndex: