sanchit-gandhi committed on
Commit
6096f07
·
verified ·
1 Parent(s): 3cbe542

Saving train state of step 120000

Browse files
checkpoint-120000-epoch-8/config.json ADDED
@@ -0,0 +1,278 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "architectures": [
3
+ "ParlerTTSForConditionalGeneration"
4
+ ],
5
+ "audio_encoder": {
6
+ "_name_or_path": "parler-tts/dac_44khZ_8kbps",
7
+ "add_cross_attention": false,
8
+ "architectures": [
9
+ "DACModel"
10
+ ],
11
+ "bad_words_ids": null,
12
+ "begin_suppress_tokens": null,
13
+ "bos_token_id": null,
14
+ "chunk_size_feed_forward": 0,
15
+ "codebook_size": 1024,
16
+ "cross_attention_hidden_size": null,
17
+ "decoder_start_token_id": null,
18
+ "diversity_penalty": 0.0,
19
+ "do_sample": false,
20
+ "early_stopping": false,
21
+ "encoder_no_repeat_ngram_size": 0,
22
+ "eos_token_id": null,
23
+ "exponential_decay_length_penalty": null,
24
+ "finetuning_task": null,
25
+ "forced_bos_token_id": null,
26
+ "forced_eos_token_id": null,
27
+ "frame_rate": 86,
28
+ "id2label": {
29
+ "0": "LABEL_0",
30
+ "1": "LABEL_1"
31
+ },
32
+ "is_decoder": false,
33
+ "is_encoder_decoder": false,
34
+ "label2id": {
35
+ "LABEL_0": 0,
36
+ "LABEL_1": 1
37
+ },
38
+ "latent_dim": 1024,
39
+ "length_penalty": 1.0,
40
+ "max_length": 20,
41
+ "min_length": 0,
42
+ "model_bitrate": 8,
43
+ "model_type": "dac",
44
+ "no_repeat_ngram_size": 0,
45
+ "num_beam_groups": 1,
46
+ "num_beams": 1,
47
+ "num_codebooks": 9,
48
+ "num_return_sequences": 1,
49
+ "output_attentions": false,
50
+ "output_hidden_states": false,
51
+ "output_scores": false,
52
+ "pad_token_id": null,
53
+ "prefix": null,
54
+ "problem_type": null,
55
+ "pruned_heads": {},
56
+ "remove_invalid_values": false,
57
+ "repetition_penalty": 1.0,
58
+ "return_dict": true,
59
+ "return_dict_in_generate": false,
60
+ "sampling_rate": 44100,
61
+ "sep_token_id": null,
62
+ "suppress_tokens": null,
63
+ "task_specific_params": null,
64
+ "temperature": 1.0,
65
+ "tf_legacy_loss": false,
66
+ "tie_encoder_decoder": false,
67
+ "tie_word_embeddings": true,
68
+ "tokenizer_class": null,
69
+ "top_k": 50,
70
+ "top_p": 1.0,
71
+ "torch_dtype": "float32",
72
+ "torchscript": false,
73
+ "typical_p": 1.0,
74
+ "use_bfloat16": false
75
+ },
76
+ "decoder": {
77
+ "_name_or_path": "./parler-tts-untrained-600M/decoder",
78
+ "activation_dropout": 0.0,
79
+ "activation_function": "gelu",
80
+ "add_cross_attention": true,
81
+ "architectures": [
82
+ "ParlerTTSForCausalLM"
83
+ ],
84
+ "attention_dropout": 0.0,
85
+ "bad_words_ids": null,
86
+ "begin_suppress_tokens": null,
87
+ "bos_token_id": 1025,
88
+ "chunk_size_feed_forward": 0,
89
+ "cross_attention_hidden_size": null,
90
+ "decoder_start_token_id": null,
91
+ "diversity_penalty": 0.0,
92
+ "do_sample": false,
93
+ "dropout": 0.1,
94
+ "early_stopping": false,
95
+ "encoder_no_repeat_ngram_size": 0,
96
+ "eos_token_id": 1024,
97
+ "exponential_decay_length_penalty": null,
98
+ "ffn_dim": 4096,
99
+ "finetuning_task": null,
100
+ "forced_bos_token_id": null,
101
+ "forced_eos_token_id": null,
102
+ "hidden_size": 1024,
103
+ "id2label": {
104
+ "0": "LABEL_0",
105
+ "1": "LABEL_1"
106
+ },
107
+ "initializer_factor": 0.02,
108
+ "is_decoder": true,
109
+ "is_encoder_decoder": false,
110
+ "label2id": {
111
+ "LABEL_0": 0,
112
+ "LABEL_1": 1
113
+ },
114
+ "layerdrop": 0.0,
115
+ "length_penalty": 1.0,
116
+ "max_length": 20,
117
+ "max_position_embeddings": 4096,
118
+ "min_length": 0,
119
+ "model_type": "parler_tts_decoder",
120
+ "no_repeat_ngram_size": 0,
121
+ "num_attention_heads": 16,
122
+ "num_beam_groups": 1,
123
+ "num_beams": 1,
124
+ "num_codebooks": 9,
125
+ "num_hidden_layers": 24,
126
+ "num_return_sequences": 1,
127
+ "output_attentions": false,
128
+ "output_hidden_states": false,
129
+ "output_scores": false,
130
+ "pad_token_id": 1024,
131
+ "prefix": null,
132
+ "problem_type": null,
133
+ "pruned_heads": {},
134
+ "remove_invalid_values": false,
135
+ "repetition_penalty": 1.0,
136
+ "return_dict": true,
137
+ "return_dict_in_generate": false,
138
+ "rope_embeddings": false,
139
+ "rope_theta": 10000.0,
140
+ "scale_embedding": false,
141
+ "sep_token_id": null,
142
+ "suppress_tokens": null,
143
+ "task_specific_params": null,
144
+ "temperature": 1.0,
145
+ "tf_legacy_loss": false,
146
+ "tie_encoder_decoder": false,
147
+ "tie_word_embeddings": false,
148
+ "tokenizer_class": null,
149
+ "top_k": 50,
150
+ "top_p": 1.0,
151
+ "torch_dtype": "float32",
152
+ "torchscript": false,
153
+ "typical_p": 1.0,
154
+ "use_bfloat16": false,
155
+ "use_cache": true,
156
+ "vocab_size": 1088
157
+ },
158
+ "decoder_start_token_id": 1025,
159
+ "is_encoder_decoder": true,
160
+ "model_type": "parler_tts",
161
+ "pad_token_id": 1024,
162
+ "prompt_cross_attention": true,
163
+ "text_encoder": {
164
+ "_name_or_path": "google/flan-t5-base",
165
+ "add_cross_attention": false,
166
+ "architectures": [
167
+ "T5ForConditionalGeneration"
168
+ ],
169
+ "bad_words_ids": null,
170
+ "begin_suppress_tokens": null,
171
+ "bos_token_id": null,
172
+ "chunk_size_feed_forward": 0,
173
+ "classifier_dropout": 0.0,
174
+ "cross_attention_hidden_size": null,
175
+ "d_ff": 2048,
176
+ "d_kv": 64,
177
+ "d_model": 768,
178
+ "decoder_start_token_id": 0,
179
+ "dense_act_fn": "gelu_new",
180
+ "diversity_penalty": 0.0,
181
+ "do_sample": false,
182
+ "dropout_rate": 0.1,
183
+ "early_stopping": false,
184
+ "encoder_no_repeat_ngram_size": 0,
185
+ "eos_token_id": 1,
186
+ "exponential_decay_length_penalty": null,
187
+ "feed_forward_proj": "gated-gelu",
188
+ "finetuning_task": null,
189
+ "forced_bos_token_id": null,
190
+ "forced_eos_token_id": null,
191
+ "id2label": {
192
+ "0": "LABEL_0",
193
+ "1": "LABEL_1"
194
+ },
195
+ "initializer_factor": 1.0,
196
+ "is_decoder": false,
197
+ "is_encoder_decoder": true,
198
+ "is_gated_act": true,
199
+ "label2id": {
200
+ "LABEL_0": 0,
201
+ "LABEL_1": 1
202
+ },
203
+ "layer_norm_epsilon": 1e-06,
204
+ "length_penalty": 1.0,
205
+ "max_length": 20,
206
+ "min_length": 0,
207
+ "model_type": "t5",
208
+ "n_positions": 512,
209
+ "no_repeat_ngram_size": 0,
210
+ "num_beam_groups": 1,
211
+ "num_beams": 1,
212
+ "num_decoder_layers": 12,
213
+ "num_heads": 12,
214
+ "num_layers": 12,
215
+ "num_return_sequences": 1,
216
+ "output_attentions": false,
217
+ "output_hidden_states": false,
218
+ "output_past": true,
219
+ "output_scores": false,
220
+ "pad_token_id": 0,
221
+ "prefix": null,
222
+ "problem_type": null,
223
+ "pruned_heads": {},
224
+ "relative_attention_max_distance": 128,
225
+ "relative_attention_num_buckets": 32,
226
+ "remove_invalid_values": false,
227
+ "repetition_penalty": 1.0,
228
+ "return_dict": true,
229
+ "return_dict_in_generate": false,
230
+ "sep_token_id": null,
231
+ "suppress_tokens": null,
232
+ "task_specific_params": {
233
+ "summarization": {
234
+ "early_stopping": true,
235
+ "length_penalty": 2.0,
236
+ "max_length": 200,
237
+ "min_length": 30,
238
+ "no_repeat_ngram_size": 3,
239
+ "num_beams": 4,
240
+ "prefix": "summarize: "
241
+ },
242
+ "translation_en_to_de": {
243
+ "early_stopping": true,
244
+ "max_length": 300,
245
+ "num_beams": 4,
246
+ "prefix": "translate English to German: "
247
+ },
248
+ "translation_en_to_fr": {
249
+ "early_stopping": true,
250
+ "max_length": 300,
251
+ "num_beams": 4,
252
+ "prefix": "translate English to French: "
253
+ },
254
+ "translation_en_to_ro": {
255
+ "early_stopping": true,
256
+ "max_length": 300,
257
+ "num_beams": 4,
258
+ "prefix": "translate English to Romanian: "
259
+ }
260
+ },
261
+ "temperature": 1.0,
262
+ "tf_legacy_loss": false,
263
+ "tie_encoder_decoder": false,
264
+ "tie_word_embeddings": false,
265
+ "tokenizer_class": null,
266
+ "top_k": 50,
267
+ "top_p": 1.0,
268
+ "torch_dtype": null,
269
+ "torchscript": false,
270
+ "typical_p": 1.0,
271
+ "use_bfloat16": false,
272
+ "use_cache": true,
273
+ "vocab_size": 32128
274
+ },
275
+ "torch_dtype": "float32",
276
+ "transformers_version": "4.40.2",
277
+ "vocab_size": 32128
278
+ }
checkpoint-120000-epoch-8/generation_config.json ADDED
@@ -0,0 +1,12 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "_from_model_config": true,
3
+ "bos_token_id": 1025,
4
+ "decoder_start_token_id": 1025,
5
+ "do_sample": true,
6
+ "eos_token_id": 1024,
7
+ "guidance_scale": 1,
8
+ "key": 10,
9
+ "max_length": 2580,
10
+ "pad_token_id": 1024,
11
+ "transformers_version": "4.40.2"
12
+ }
checkpoint-120000-epoch-8/optimizer.bin ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:733b094e78f727fce8a0183cdcfc72f5e2b154ed2934959c320e2465693c9577
3
+ size 3652769047
checkpoint-120000-epoch-8/pytorch_model.bin ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:ce7533373536c757b73a74a2ba6185974c2d35e466e38f37a6781da4031a98e8
3
+ size 2605239710
checkpoint-120000-epoch-8/random_states_0.pkl ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:a23edacfa0329be5ccd2f08651b598e5c8cd31d1fd8e33a3ed02c60c8a3654a6
3
+ size 16036
checkpoint-120000-epoch-8/random_states_1.pkl ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:53c4ef324c16e8b4466ddbd1b26ea5df0eab4d6fe281391be521efad9f1c87f3
3
+ size 16100
checkpoint-120000-epoch-8/random_states_2.pkl ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:44f53b0c4745162bed7c7899526b263264a4d5d5b4f4e2c1bc4b6af765c4b6e1
3
+ size 16100
checkpoint-120000-epoch-8/random_states_3.pkl ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:7e432fe468eb01fa59dcd2d3c7c8969f260b157c4395d35773b847f11a37b12b
3
+ size 16100
checkpoint-120000-epoch-8/random_states_4.pkl ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:5c41036f9a24b7f68c02bf44c037565e222ee8c40516670494732a1645fabb39
3
+ size 16100
checkpoint-120000-epoch-8/random_states_5.pkl ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:0ee6a0009db3cca6c35019eabafa38b2a46def11da55b2c5e140f70974e8ae50
3
+ size 16100
checkpoint-120000-epoch-8/random_states_6.pkl ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:525f5729c247c1493b3cf5283900401c86948a7eae7b41979b3784e3efa6b2bb
3
+ size 16100
checkpoint-120000-epoch-8/random_states_7.pkl ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:3b31f40c509348d672e052602f82d0d37f85f9dcadbc107dc7217a8e3e6f3092
3
+ size 16036
checkpoint-120000-epoch-8/scheduler.bin ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:2a46a1b5d12218eb49696470fb7337ce7c2ac2f6cb2f18a57ba0f1af48738171
3
+ size 1000
starting_point_0.01.json CHANGED
@@ -1,6 +1,6 @@
1
  {
2
  "model_name_or_path": "parler-tts/parler-tts-untrained-600M-cross-attention",
3
- "save_to_disk": "/fsx/yoach/tmp/artefacts/10k_hours_processed_punctuated/",
4
  "temporary_save_to_disk": "/scratch/tmp_dataset_audio/",
5
  "push_to_hub": true,
6
 
@@ -10,6 +10,7 @@
10
  "prompt_tokenizer_name":"google/flan-t5-base",
11
 
12
  "report_to": ["wandb"],
 
13
  "overwrite_output_dir": false,
14
  "output_dir": "./",
15
 
 
1
  {
2
  "model_name_or_path": "parler-tts/parler-tts-untrained-600M-cross-attention",
3
+ "save_to_disk": "/fsx/sanchit/10k_hours_processed_punctuated",
4
  "temporary_save_to_disk": "/scratch/tmp_dataset_audio/",
5
  "push_to_hub": true,
6
 
 
10
  "prompt_tokenizer_name":"google/flan-t5-base",
11
 
12
  "report_to": ["wandb"],
13
+ "wandb_run_name": "parler-tts-600M-cross-attention",
14
  "overwrite_output_dir": false,
15
  "output_dir": "./",
16
 
training/__pycache__/arguments.cpython-311.pyc CHANGED
Binary files a/training/__pycache__/arguments.cpython-311.pyc and b/training/__pycache__/arguments.cpython-311.pyc differ
 
training/__pycache__/data.cpython-311.pyc CHANGED
Binary files a/training/__pycache__/data.cpython-311.pyc and b/training/__pycache__/data.cpython-311.pyc differ
 
training/__pycache__/eval.cpython-311.pyc CHANGED
Binary files a/training/__pycache__/eval.cpython-311.pyc and b/training/__pycache__/eval.cpython-311.pyc differ
 
training/__pycache__/utils.cpython-311.pyc CHANGED
Binary files a/training/__pycache__/utils.cpython-311.pyc and b/training/__pycache__/utils.cpython-311.pyc differ
 
training/arguments.py CHANGED
@@ -218,7 +218,7 @@ class DataTrainingArguments:
218
  metadata={
219
  "help": (
220
  "If set, filter samples with descriptions that are longer than `max_description_token_length` tokens."
221
- "Also, used to set maximum desription token length if `pad_to_max_length=True`."
222
  )
223
  },
224
  )
@@ -277,6 +277,12 @@ class DataTrainingArguments:
277
  default="parler-speech",
278
  metadata={"help": "The name of the wandb project."},
279
  )
 
 
 
 
 
 
280
  save_to_disk: str = field(
281
  default=None,
282
  metadata={
 
218
  metadata={
219
  "help": (
220
  "If set, filter samples with descriptions that are longer than `max_description_token_length` tokens."
221
+ "Also, used to set maximum description token length if `pad_to_max_length=True`."
222
  )
223
  },
224
  )
 
277
  default="parler-speech",
278
  metadata={"help": "The name of the wandb project."},
279
  )
280
+ wandb_run_name: str = field(
281
+ default=None,
282
+ metadata={
283
+ "help": "If specified, the name of the run. If not specified, wandb will give a random name to this run."
284
+ },
285
+ )
286
  save_to_disk: str = field(
287
  default=None,
288
  metadata={
training/data.py CHANGED
@@ -31,7 +31,12 @@ class DataCollatorEncodecWithPadding:
31
  audios = [feature[self.audio_column_name]["array"] for feature in features]
32
  len_audio = [len(audio) for audio in audios]
33
 
34
- batch = self.feature_extractor(audios, return_tensors="pt", padding=self.padding, max_length=self.max_length)
 
 
 
 
 
35
  batch["len_audio"] = torch.tensor(len_audio).unsqueeze(1)
36
  return batch
37
 
 
31
  audios = [feature[self.audio_column_name]["array"] for feature in features]
32
  len_audio = [len(audio) for audio in audios]
33
 
34
+ # since resampling has already been performed in the 'load_multiple_datasets' function,
35
+ # a fixed sampling_rate(44100hz) is passed to the feature_extractor.
36
+ sampling_rate = self.feature_extractor.sampling_rate
37
+ batch = self.feature_extractor(
38
+ audios, sampling_rate=sampling_rate, return_tensors="pt", padding=self.padding, max_length=self.max_length
39
+ )
40
  batch["len_audio"] = torch.tensor(len_audio).unsqueeze(1)
41
  return batch
42
 
training/eval.py CHANGED
@@ -47,8 +47,7 @@ def wer(asr_model_name_or_path, prompts, audios, device, per_device_eval_batch_s
47
  normalized_references = []
48
 
49
  for pred, ref in zip(transcriptions, prompts):
50
- normalizer = english_normalizer
51
-
52
  norm_ref = normalizer(ref)
53
  if len(norm_ref) > 0:
54
  norm_pred = normalizer(pred["text"])
 
47
  normalized_references = []
48
 
49
  for pred, ref in zip(transcriptions, prompts):
50
+ normalizer = english_normalizer if return_language and pred["chunks"][0]["language"] == "english" else basic_normalizer
 
51
  norm_ref = normalizer(ref)
52
  if len(norm_ref) > 0:
53
  norm_pred = normalizer(pred["text"])
training/run_parler_tts_training.py CHANGED
@@ -98,9 +98,6 @@ def main():
98
 
99
  ####### A. Preparation
100
  kwargs_handlers = [InitProcessGroupKwargs(timeout=timedelta(minutes=60))]
101
- if training_args.torch_compile:
102
- # TODO(YL): add more compile modes?
103
- kwargs_handlers.append(TorchDynamoPlugin(backend="inductor", mode="default")) # reduce-overhead
104
 
105
  accelerator = Accelerator(
106
  gradient_accumulation_steps=training_args.gradient_accumulation_steps,
@@ -129,6 +126,7 @@ def main():
129
  "adam_beta2": training_args.adam_beta2,
130
  "temperature": model_args.temperature,
131
  },
 
132
  )
133
 
134
  # Detecting last checkpoint and eventually continue from last checkpoint
@@ -136,7 +134,7 @@ def main():
136
  if os.path.isdir(training_args.output_dir) and training_args.do_train and not training_args.overwrite_output_dir:
137
  last_checkpoint = get_last_checkpoint(training_args.output_dir)
138
  if last_checkpoint is None and len(os.listdir(training_args.output_dir)) > 0:
139
- raise ValueError(
140
  f"Output directory ({training_args.output_dir}) already exists and is not empty. "
141
  "Use --overwrite_output_dir to overcome."
142
  )
@@ -314,6 +312,7 @@ def main():
314
  token=data_args.token,
315
  trust_remote_code=data_args.trust_remote_code,
316
  )
 
317
 
318
  # enable gradient checkpointing if necessary
319
  if training_args.gradient_checkpointing:
@@ -334,8 +333,8 @@ def main():
334
  feature_extractor_input_name = feature_extractor.model_input_names[0]
335
  audio_encoder_pad_token_id = config.decoder.pad_token_id
336
  audio_encoder_eos_token_id = config.decoder.eos_token_id
337
- audio_encoder_bos_token_id = model.generation_config.decoder_start_token_id
338
- max_length = model.generation_config.max_length
339
  num_codebooks = model.decoder.config.num_codebooks
340
  bandwidth = model_args.bandwidth
341
 
@@ -538,7 +537,7 @@ def main():
538
  logger.info(f"Dataset saved at {data_args.save_to_disk}")
539
 
540
  audio_max_length = None
541
- if training_args.torch_compile:
542
  audio_max_length = max(vectorized_datasets["train"]["target_length"])
543
  with accelerator.main_process_first():
544
  max_sample = vectorized_datasets["train"].filter(
@@ -548,6 +547,18 @@ def main():
548
  )
549
  audio_max_length = torch.tensor(max_sample[0]["labels"]).shape[1]
550
 
 
 
 
 
 
 
 
 
 
 
 
 
551
  # for large datasets it is advised to run the preprocessing on a
552
  # single machine first with ``args.preprocessing_only`` since there will most likely
553
  # be a timeout when running the script in distributed mode.
@@ -670,6 +681,8 @@ def main():
670
  checkpoint = last_checkpoint
671
 
672
  if accelerator.is_main_process:
 
 
673
  if training_args.push_to_hub:
674
  api = HfApi(token=training_args.hub_token)
675
 
@@ -682,8 +695,6 @@ def main():
682
  with open(os.path.join(training_args.output_dir, ".gitignore"), "w+") as gitignore:
683
  if "wandb" not in gitignore:
684
  gitignore.write("wandb\n")
685
- elif training_args.output_dir is not None:
686
- os.makedirs(training_args.output_dir, exist_ok=True)
687
  accelerator.wait_for_everyone()
688
 
689
  # Now save everything to be able to create a single processor later
@@ -740,7 +751,13 @@ def main():
740
  "do_sample": model_args.do_sample,
741
  "temperature": model_args.temperature,
742
  "max_length": model_args.max_length,
 
 
 
 
743
  }
 
 
744
 
745
  # Define gradient update step fn
746
  def train_step(
@@ -869,9 +886,11 @@ def main():
869
  # safe_serialization=False to avoid shared tensors saving issue (TODO(YL): it's a temporary fix)
870
  # https://github.com/huggingface/transformers/issues/27293#issuecomment-1872560074
871
  accelerator.save_state(output_dir=intermediate_dir, safe_serialization=False)
 
 
872
  accelerator.wait_for_everyone()
873
  if accelerator.is_main_process:
874
- rotate_checkpoints(
875
  training_args.save_total_limit, output_dir=training_args.output_dir, logger=logger
876
  )
877
 
@@ -886,6 +905,7 @@ def main():
886
  folder_path=training_args.output_dir,
887
  commit_message=f"Saving train state of step {cur_step}",
888
  run_as_future=True,
 
889
  )
890
 
891
  if training_args.do_eval and (cur_step % eval_steps == 0 or cur_step == total_train_steps):
 
98
 
99
  ####### A. Preparation
100
  kwargs_handlers = [InitProcessGroupKwargs(timeout=timedelta(minutes=60))]
 
 
 
101
 
102
  accelerator = Accelerator(
103
  gradient_accumulation_steps=training_args.gradient_accumulation_steps,
 
126
  "adam_beta2": training_args.adam_beta2,
127
  "temperature": model_args.temperature,
128
  },
129
+ init_kwargs={"wandb": {"name": data_args.wandb_run_name}} if data_args.wandb_run_name else {},
130
  )
131
 
132
  # Detecting last checkpoint and eventually continue from last checkpoint
 
134
  if os.path.isdir(training_args.output_dir) and training_args.do_train and not training_args.overwrite_output_dir:
135
  last_checkpoint = get_last_checkpoint(training_args.output_dir)
136
  if last_checkpoint is None and len(os.listdir(training_args.output_dir)) > 0:
137
+ logger.info(
138
  f"Output directory ({training_args.output_dir}) already exists and is not empty. "
139
  "Use --overwrite_output_dir to overcome."
140
  )
 
312
  token=data_args.token,
313
  trust_remote_code=data_args.trust_remote_code,
314
  )
315
+ generation_config = model.generation_config
316
 
317
  # enable gradient checkpointing if necessary
318
  if training_args.gradient_checkpointing:
 
333
  feature_extractor_input_name = feature_extractor.model_input_names[0]
334
  audio_encoder_pad_token_id = config.decoder.pad_token_id
335
  audio_encoder_eos_token_id = config.decoder.eos_token_id
336
+ audio_encoder_bos_token_id = generation_config.decoder_start_token_id
337
+ max_length = generation_config.max_length
338
  num_codebooks = model.decoder.config.num_codebooks
339
  bandwidth = model_args.bandwidth
340
 
 
537
  logger.info(f"Dataset saved at {data_args.save_to_disk}")
538
 
539
  audio_max_length = None
540
+ if padding == "max_length":
541
  audio_max_length = max(vectorized_datasets["train"]["target_length"])
542
  with accelerator.main_process_first():
543
  max_sample = vectorized_datasets["train"].filter(
 
547
  )
548
  audio_max_length = torch.tensor(max_sample[0]["labels"]).shape[1]
549
 
550
+ if training_args.group_by_length:
551
+ # apply a simple heuristic to take into account audio and text lengths
552
+ def add_target_lengths(target_length, prompt, description):
553
+ return {"target_length": target_length + len(prompt) + len(description)}
554
+
555
+ with accelerator.main_process_first():
556
+ vectorized_datasets = vectorized_datasets.map(
557
+ add_target_lengths,
558
+ num_proc=num_workers,
559
+ input_columns=["target_length", "prompt_input_ids", "input_ids"],
560
+ )
561
+
562
  # for large datasets it is advised to run the preprocessing on a
563
  # single machine first with ``args.preprocessing_only`` since there will most likely
564
  # be a timeout when running the script in distributed mode.
 
681
  checkpoint = last_checkpoint
682
 
683
  if accelerator.is_main_process:
684
+ if training_args.output_dir is not None:
685
+ os.makedirs(training_args.output_dir, exist_ok=True)
686
  if training_args.push_to_hub:
687
  api = HfApi(token=training_args.hub_token)
688
 
 
695
  with open(os.path.join(training_args.output_dir, ".gitignore"), "w+") as gitignore:
696
  if "wandb" not in gitignore:
697
  gitignore.write("wandb\n")
 
 
698
  accelerator.wait_for_everyone()
699
 
700
  # Now save everything to be able to create a single processor later
 
751
  "do_sample": model_args.do_sample,
752
  "temperature": model_args.temperature,
753
  "max_length": model_args.max_length,
754
+ # Because of the delayed pattern mask, generation might stop earlier because of unexpected behaviour
755
+ # on the first tokens of the codebooks that are delayed.
756
+ # This fixes the issue.
757
+ "min_new_tokens": num_codebooks + 1,
758
  }
759
+ for key in gen_kwargs:
760
+ generation_config.key = gen_kwargs[key]
761
 
762
  # Define gradient update step fn
763
  def train_step(
 
886
  # safe_serialization=False to avoid shared tensors saving issue (TODO(YL): it's a temporary fix)
887
  # https://github.com/huggingface/transformers/issues/27293#issuecomment-1872560074
888
  accelerator.save_state(output_dir=intermediate_dir, safe_serialization=False)
889
+ config.save_pretrained(intermediate_dir)
890
+ generation_config.save_pretrained(intermediate_dir)
891
  accelerator.wait_for_everyone()
892
  if accelerator.is_main_process:
893
+ checkpoints_to_be_deleted = rotate_checkpoints(
894
  training_args.save_total_limit, output_dir=training_args.output_dir, logger=logger
895
  )
896
 
 
905
  folder_path=training_args.output_dir,
906
  commit_message=f"Saving train state of step {cur_step}",
907
  run_as_future=True,
908
+ delete_patterns=checkpoints_to_be_deleted,
909
  )
910
 
911
  if training_args.do_eval and (cur_step % eval_steps == 0 or cur_step == total_train_steps):
training/utils.py CHANGED
@@ -3,7 +3,7 @@ import re
3
  import shutil
4
  from pathlib import Path
5
  from dataclasses import field
6
- from typing import Dict, List
7
 
8
  import torch
9
  from wandb import Audio
@@ -44,7 +44,7 @@ def sorted_checkpoints(output_dir=None, checkpoint_prefix="checkpoint") -> List[
44
  return checkpoints_sorted
45
 
46
 
47
- def rotate_checkpoints(save_total_limit=None, output_dir=None, checkpoint_prefix="checkpoint", logger=None) -> None:
48
  """Helper function to delete old checkpoints."""
49
  if save_total_limit is None or save_total_limit <= 0:
50
  return
@@ -58,6 +58,8 @@ def rotate_checkpoints(save_total_limit=None, output_dir=None, checkpoint_prefix
58
  for checkpoint in checkpoints_to_be_deleted:
59
  logger.info(f"Deleting older checkpoint [{checkpoint}] due to args.save_total_limit")
60
  shutil.rmtree(checkpoint, ignore_errors=True)
 
 
61
 
62
 
63
  def log_metric(
 
3
  import shutil
4
  from pathlib import Path
5
  from dataclasses import field
6
+ from typing import Dict, List, Union
7
 
8
  import torch
9
  from wandb import Audio
 
44
  return checkpoints_sorted
45
 
46
 
47
+ def rotate_checkpoints(save_total_limit=None, output_dir=None, checkpoint_prefix="checkpoint", logger=None) -> Union[List, None]:
48
  """Helper function to delete old checkpoints."""
49
  if save_total_limit is None or save_total_limit <= 0:
50
  return
 
58
  for checkpoint in checkpoints_to_be_deleted:
59
  logger.info(f"Deleting older checkpoint [{checkpoint}] due to args.save_total_limit")
60
  shutil.rmtree(checkpoint, ignore_errors=True)
61
+ checkpoints_to_be_deleted = [f"*{Path(checkpoint).absolute().name}*" for checkpoint in checkpoints_to_be_deleted]
62
+ return checkpoints_to_be_deleted
63
 
64
 
65
  def log_metric(