kunato commited on
Commit
51ee87b
·
verified ·
1 Parent(s): 47d85a6

Update modeling_typhoon2audio.py

Browse files
Files changed (1) hide show
  1. modeling_typhoon2audio.py +5 -3
modeling_typhoon2audio.py CHANGED
@@ -18,6 +18,7 @@ from transformers import (
18
  WhisperModel,
19
  PreTrainedModel,
20
  AutoTokenizer,
 
21
  AutoModelForCausalLM,
22
  )
23
  from transformers.cache_utils import Cache, StaticCache
@@ -63,6 +64,7 @@ from transformers.modeling_utils import (
63
  apply_chunking_to_forward,
64
  find_pruneable_heads_and_indices,
65
  prune_linear_layer,
 
66
  )
67
  from transformers.models.bert.configuration_bert import BertConfig
68
 
@@ -841,9 +843,9 @@ class Typhoon2AudioForConditionalGeneration(PreTrainedModel, GenerationMixin):
841
  self.second_stride = config.second_stride
842
 
843
  # 2. LLM (e.g., Llama3)
844
- self.llama_model = AutoModelForCausalLM.from_pretrained(
845
- config.llama_base_model, attn_implementation=attn_implementation
846
- )
847
  # tokenizer
848
  self.llama_tokenizer = AutoTokenizer.from_pretrained(
849
  config.llama_base_model, use_fast=False
 
18
  WhisperModel,
19
  PreTrainedModel,
20
  AutoTokenizer,
21
+ AutoConfig,
22
  AutoModelForCausalLM,
23
  )
24
  from transformers.cache_utils import Cache, StaticCache
 
64
  apply_chunking_to_forward,
65
  find_pruneable_heads_and_indices,
66
  prune_linear_layer,
67
+ no_init_weights
68
  )
69
  from transformers.models.bert.configuration_bert import BertConfig
70
 
 
843
  self.second_stride = config.second_stride
844
 
845
  # 2. LLM (e.g., Llama3)
846
+ with no_init_weights(_enable=True):
847
+ llm_config = AutoConfig.from_pretrained(config.llama_base_model)
848
+ self.llama_model = AutoModelForCausalLM.from_config(llm_config, attn_implementation=attn_implementation)
849
  # tokenizer
850
  self.llama_tokenizer = AutoTokenizer.from_pretrained(
851
  config.llama_base_model, use_fast=False