Update modeling_typhoon2audio.py
modeling_typhoon2audio.py CHANGED
@@ -18,6 +18,7 @@ from transformers import (
     WhisperModel,
     PreTrainedModel,
     AutoTokenizer,
+    AutoConfig,
     AutoModelForCausalLM,
 )
 from transformers.cache_utils import Cache, StaticCache
@@ -63,6 +64,7 @@ from transformers.modeling_utils import (
     apply_chunking_to_forward,
     find_pruneable_heads_and_indices,
     prune_linear_layer,
+    no_init_weights
 )
 from transformers.models.bert.configuration_bert import BertConfig
 
@@ -841,9 +843,9 @@ class Typhoon2AudioForConditionalGeneration(PreTrainedModel, GenerationMixin):
         self.second_stride = config.second_stride
 
         # 2. LLM (e.g., Llama3)
-        self.llama_model = AutoModelForCausalLM.from_pretrained(
-            config.llama_base_model
-        )
+        with no_init_weights(_enable=True):
+            llm_config = AutoConfig.from_pretrained(config.llama_base_model)
+            self.llama_model = AutoModelForCausalLM.from_config(llm_config, attn_implementation=attn_implementation)
         # tokenizer
         self.llama_tokenizer = AutoTokenizer.from_pretrained(
             config.llama_base_model, use_fast=False