THUdyh commited on
Commit
fe96dab
1 Parent(s): 35e6890

Update oryx/model/builder.py

Browse files
Files changed (1) hide show
  1. oryx/model/builder.py +2 -1
oryx/model/builder.py CHANGED
@@ -76,7 +76,8 @@ def load_pretrained_model(model_path, model_base, model_name, load_8bit=False, l
76
  if not vision_tower.is_loaded:
77
  vision_tower.load_model(device_map=device_map)
78
  if device_map != "auto":
79
- vision_tower.to(device="cuda", dtype=torch.bfloat16)
 
80
  else:
81
  vision_tower.to(device="cuda:0", dtype=torch.bfloat16)
82
  image_processor = vision_tower.image_processor
 
76
  if not vision_tower.is_loaded:
77
  vision_tower.load_model(device_map=device_map)
78
  if device_map != "auto":
79
+ vision_tower = vision_tower.bfloat16()
80
+ vision_tower = vision_tower.to("cuda")
81
  else:
82
  vision_tower.to(device="cuda:0", dtype=torch.bfloat16)
83
  image_processor = vision_tower.image_processor