THUdyh commited on
Commit
44b12c2
1 Parent(s): fe96dab

Update oryx/model/builder.py

Browse files
Files changed (1) hide show
  1. oryx/model/builder.py +5 -5
oryx/model/builder.py CHANGED
@@ -75,11 +75,11 @@ def load_pretrained_model(model_path, model_base, model_name, load_8bit=False, l
75
  print("Loading vision tower...")
76
  if not vision_tower.is_loaded:
77
  vision_tower.load_model(device_map=device_map)
78
- if device_map != "auto":
79
- vision_tower = vision_tower.bfloat16()
80
- vision_tower = vision_tower.to("cuda")
81
- else:
82
- vision_tower.to(device="cuda:0", dtype=torch.bfloat16)
83
  image_processor = vision_tower.image_processor
84
  print("Loading vision tower succeeded.")
85
  if hasattr(model.config, "max_sequence_length"):
 
75
  print("Loading vision tower...")
76
  if not vision_tower.is_loaded:
77
  vision_tower.load_model(device_map=device_map)
78
+ # if device_map != "auto":
79
+ # vision_tower = vision_tower.bfloat16()
80
+ # vision_tower = vision_tower.to("cuda")
81
+ # else:
82
+ # vision_tower.to(device="cuda:0", dtype=torch.bfloat16)
83
  image_processor = vision_tower.image_processor
84
  print("Loading vision tower succeeded.")
85
  if hasattr(model.config, "max_sequence_length"):