Update modeling_qwen.py
Browse files- modeling_qwen.py +1 -1
modeling_qwen.py
CHANGED
@@ -556,7 +556,7 @@ class QWenModel(QWenPreTrainedModel):
|
|
556 |
output_hidden_states: Optional[bool] = None,
|
557 |
return_dict: Optional[bool] = None,
|
558 |
):
|
559 |
-
if past_key_values is None and torch.any(input_ids == self.config.visual['image_start_id']):
|
560 |
bos_pos = torch.where(input_ids == self.config.visual['image_start_id'])
|
561 |
eos_pos = torch.where(input_ids == self.config.visual['image_start_id'] + 1)
|
562 |
assert (bos_pos[0] == eos_pos[0]).all()
|
|
|
556 |
output_hidden_states: Optional[bool] = None,
|
557 |
return_dict: Optional[bool] = None,
|
558 |
):
|
559 |
+
if past_key_values is None and input_ids is not None and torch.any(input_ids == self.config.visual['image_start_id']):
|
560 |
bos_pos = torch.where(input_ids == self.config.visual['image_start_id'])
|
561 |
eos_pos = torch.where(input_ids == self.config.visual['image_start_id'] + 1)
|
562 |
assert (bos_pos[0] == eos_pos[0]).all()
|