Replace the inplace operation (#7)
Browse files- Update modeling_minicpmo.py (909a86b1f20fd048c8a8fbe4119910812cc3eaaf)
Co-authored-by: Zhangchi Feng <[email protected]>
- modeling_minicpmo.py +10 -6
modeling_minicpmo.py
CHANGED
@@ -377,10 +377,12 @@ class MiniCPMO(MiniCPMOPreTrainedModel):
|
|
377 |
else:
|
378 |
vllm_embedding = self.llm.model.embed_tokens(data["input_ids"])
|
379 |
|
|
|
|
|
380 |
vision_hidden_states = [
|
381 |
i.type(vllm_embedding.dtype) if isinstance(i, torch.Tensor) else i for i in vision_hidden_states
|
382 |
]
|
383 |
-
|
384 |
bs = len(data["input_ids"])
|
385 |
for i in range(bs):
|
386 |
cur_vs_hs = vision_hidden_states[i]
|
@@ -392,15 +394,16 @@ class MiniCPMO(MiniCPMOPreTrainedModel):
|
|
392 |
[torch.arange(r[0], r[1], dtype=torch.long) for r in cur_image_bound]
|
393 |
).to(vllm_embedding.device)
|
394 |
|
395 |
-
cur_vllm_emb.
|
396 |
0,
|
397 |
image_indices.view(-1, 1).repeat(1, cur_vllm_emb.shape[-1]),
|
398 |
cur_vs_hs.view(-1, cur_vs_hs.shape[-1]),
|
399 |
)
|
|
|
400 |
elif self.training:
|
401 |
-
|
402 |
|
403 |
-
return
|
404 |
|
405 |
def get_audio_embedding_streaming(self, data):
|
406 |
r"""
|
@@ -595,7 +598,7 @@ class MiniCPMO(MiniCPMOPreTrainedModel):
|
|
595 |
elif self.training:
|
596 |
for i in range(bs):
|
597 |
# dummy audio_embeddings
|
598 |
-
input_embeddings
|
599 |
|
600 |
return input_embeddings
|
601 |
|
@@ -751,7 +754,7 @@ class MiniCPMO(MiniCPMOPreTrainedModel):
|
|
751 |
input_ids=None,
|
752 |
pixel_values=None,
|
753 |
tgt_sizes=None,
|
754 |
-
audio_features=
|
755 |
audio_feature_lens=None,
|
756 |
image_bound=None,
|
757 |
audio_bounds=None,
|
@@ -2655,6 +2658,7 @@ class ConditionalChatTTS(PreTrainedModel):
|
|
2655 |
"""
|
2656 |
|
2657 |
config_class = ConditionalChatTTSConfig
|
|
|
2658 |
|
2659 |
def __init__(self, config: ConditionalChatTTSConfig):
|
2660 |
super().__init__(config)
|
|
|
377 |
else:
|
378 |
vllm_embedding = self.llm.model.embed_tokens(data["input_ids"])
|
379 |
|
380 |
+
new_vllm_embedding = vllm_embedding.clone()
|
381 |
+
|
382 |
vision_hidden_states = [
|
383 |
i.type(vllm_embedding.dtype) if isinstance(i, torch.Tensor) else i for i in vision_hidden_states
|
384 |
]
|
385 |
+
|
386 |
bs = len(data["input_ids"])
|
387 |
for i in range(bs):
|
388 |
cur_vs_hs = vision_hidden_states[i]
|
|
|
394 |
[torch.arange(r[0], r[1], dtype=torch.long) for r in cur_image_bound]
|
395 |
).to(vllm_embedding.device)
|
396 |
|
397 |
+
new_vllm_embedding[i] = cur_vllm_emb.scatter(
|
398 |
0,
|
399 |
image_indices.view(-1, 1).repeat(1, cur_vllm_emb.shape[-1]),
|
400 |
cur_vs_hs.view(-1, cur_vs_hs.shape[-1]),
|
401 |
)
|
402 |
+
|
403 |
elif self.training:
|
404 |
+
new_vllm_embedding[i] += cur_vs_hs[0].mean() * 0
|
405 |
|
406 |
+
return new_vllm_embedding, vision_hidden_states
|
407 |
|
408 |
def get_audio_embedding_streaming(self, data):
|
409 |
r"""
|
|
|
598 |
elif self.training:
|
599 |
for i in range(bs):
|
600 |
# dummy audio_embeddings
|
601 |
+
input_embeddings = input_embeddings + audio_embeddings[0].mean() * 0
|
602 |
|
603 |
return input_embeddings
|
604 |
|
|
|
754 |
input_ids=None,
|
755 |
pixel_values=None,
|
756 |
tgt_sizes=None,
|
757 |
+
audio_features=[],
|
758 |
audio_feature_lens=None,
|
759 |
image_bound=None,
|
760 |
audio_bounds=None,
|
|
|
2658 |
"""
|
2659 |
|
2660 |
config_class = ConditionalChatTTSConfig
|
2661 |
+
_no_split_modules = []
|
2662 |
|
2663 |
def __init__(self, config: ConditionalChatTTSConfig):
|
2664 |
super().__init__(config)
|