erwold commited on
Commit
7ffc337
·
1 Parent(s): 9590121
Files changed (1) hide show
  1. app.py +9 -1
app.py CHANGED
@@ -82,7 +82,15 @@ qwen2vl_processor = AutoProcessor.from_pretrained(
82
  )
83
 
84
  # 加载 connector 和 embedder 到 CPU
85
- connector = nn.Linear(3584, 4096).to(dtype).cpu()
 
 
 
 
 
 
 
 
86
  connector_path = os.path.join(MODEL_CACHE_DIR, "qwen2-vl/connector.pt")
87
  connector_state = torch.load(connector_path, map_location='cpu')
88
  connector_state = {k.replace('module.', ''): v.to(dtype) for k, v in connector_state.items()}
 
82
  )
83
 
84
  # 加载 connector 和 embedder 到 CPU
85
+ class Qwen2Connector(nn.Module):
86
+ def __init__(self, input_dim=3584, output_dim=4096):
87
+ super().__init__()
88
+ self.linear = nn.Linear(input_dim, output_dim)
89
+
90
+ def forward(self, x):
91
+ return self.linear(x)
92
+
93
+ connector = Qwen2Connector().to(dtype).cpu()
94
  connector_path = os.path.join(MODEL_CACHE_DIR, "qwen2-vl/connector.pt")
95
  connector_state = torch.load(connector_path, map_location='cpu')
96
  connector_state = {k.replace('module.', ''): v.to(dtype) for k, v in connector_state.items()}