ginipick committed on
Commit 6a9adba · verified · 1 Parent(s): a7b49e3

Update app-backup.py

Files changed (1)
  1. app-backup.py +175 -189
app-backup.py CHANGED
@@ -20,6 +20,31 @@ logging.basicConfig(
     ]
 )
 
+def optimize_gpu_settings():
+    if torch.cuda.is_available():
+        # Optimize GPU memory management
+        torch.backends.cuda.matmul.allow_tf32 = True
+        torch.backends.cudnn.benchmark = True
+        torch.backends.cudnn.enabled = True
+        torch.backends.cudnn.deterministic = False
+
+        # Memory settings optimized for the L40S
+        torch.cuda.empty_cache()
+        torch.cuda.set_device(0)
+
+        # CUDA stream optimization
+        torch.cuda.Stream(0)
+
+        # Memory allocation optimization
+        os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'max_split_size_mb:512'
+
+        logging.info(f"Using GPU: {torch.cuda.get_device_name(0)}")
+        logging.info(f"Available GPU memory: {torch.cuda.get_device_properties(0).total_memory / 1024**3:.2f} GB")
+
+        # L40S-specific setting
+        if 'L40S' in torch.cuda.get_device_name(0):
+            torch.cuda.set_per_process_memory_fraction(0.95)
+
 def analyze_lyrics(lyrics, repeat_chorus=2):
     lines = [line.strip() for line in lyrics.split('\n') if line.strip()]
 
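One general PyTorch caveat worth noting here (not something the commit enforces): PYTORCH_CUDA_ALLOC_CONF is read when the CUDA caching allocator initializes, so the setting only takes effect if it is in the environment before the first CUDA allocation. A minimal sketch of the safe pattern:

import os
os.environ.setdefault('PYTORCH_CUDA_ALLOC_CONF', 'max_split_size_mb:512')

import torch  # imported after the env var so the allocator can pick it up

if torch.cuda.is_available():
    torch.cuda.init()  # allocator initializes with the setting above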
@@ -36,42 +61,55 @@ def analyze_lyrics(lyrics, repeat_chorus=2):
         'chorus': [],
         'bridge': []
     }
+    last_section = None
+
+    # Find the last section tag
+    for i, line in enumerate(lines):
+        if '[verse]' in line.lower() or '[chorus]' in line.lower() or '[bridge]' in line.lower():
+            last_section = i
 
-    for line in lines:
+    for i, line in enumerate(lines):
         lower_line = line.lower()
+
+        # Handle section tags
         if '[verse]' in lower_line:
+            if current_section:  # store the previous section's lines
+                section_lines[current_section].extend(lines[last_section_start:i])
             current_section = 'verse'
             sections['verse'] += 1
+            last_section_start = i + 1
             continue
         elif '[chorus]' in lower_line:
+            if current_section:
+                section_lines[current_section].extend(lines[last_section_start:i])
             current_section = 'chorus'
             sections['chorus'] += 1
+            last_section_start = i + 1
             continue
         elif '[bridge]' in lower_line:
+            if current_section:
+                section_lines[current_section].extend(lines[last_section_start:i])
             current_section = 'bridge'
             sections['bridge'] += 1
+            last_section_start = i + 1
             continue
 
-        # Add the line to the current section
-        if current_section:
-            section_lines[current_section].append(line)
+    # Append the lines of the final section
+    if current_section and last_section_start < len(lines):
+        section_lines[current_section].extend(lines[last_section_start:])
 
-    # If the chorus appears only once and repeat_chorus > 1,
-    # repeat it by duplicating the whole chorus block
-    if sections['chorus'] == 1 and repeat_chorus > 1:
-        chorus_block = section_lines['chorus'][:]
+    # Handle chorus repetition
+    if sections['chorus'] > 0 and repeat_chorus > 1:
+        original_chorus = section_lines['chorus'][:]
         for _ in range(repeat_chorus - 1):
-            section_lines['chorus'].extend(chorus_block)
-
-    # Recalculate the line count
-    new_total_lines = sum(len(section_lines[sec]) for sec in section_lines)
-
-    return sections, (sections['verse'] + sections['chorus'] + sections['bridge']), new_total_lines, {
-        'verse': len(section_lines['verse']),
-        'chorus': len(section_lines['chorus']),
-        'bridge': len(section_lines['bridge'])
-    }
+            section_lines['chorus'].extend(original_chorus)
+
+    # Log per-section line counts
+    logging.info(f"Section line counts - Verse: {len(section_lines['verse'])}, "
+                 f"Chorus: {len(section_lines['chorus'])}, "
+                 f"Bridge: {len(section_lines['bridge'])}")
 
+    return sections, (sections['verse'] + sections['chorus'] + sections['bridge']), len(lines), section_lines
 
 def calculate_generation_params(lyrics):
     sections, total_sections, total_lines, section_lines = analyze_lyrics(lyrics)
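An illustrative walkthrough of the rewritten parser (assuming current_section is initialized to None earlier in the function, outside this hunk):

sample = """[verse]
line one
line two
[chorus]
hook line"""

sections, total_sections, total_lines, section_lines = analyze_lyrics(sample, repeat_chorus=2)
# sections       -> {'verse': 1, 'chorus': 1, 'bridge': 0}
# total_sections -> 2
# total_lines    -> 5  (len(lines): tag lines are counted too)
# section_lines['verse']  -> ['line one', 'line two']
# section_lines['chorus'] -> ['hook line', 'hook line']  (repeated once by repeat_chorus=2)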
@@ -83,30 +121,31 @@ def calculate_generation_params(lyrics):
         'bridge': 5  # 5 seconds per line for a bridge
     }
 
-    # Calculate the expected duration of each section
-    section_durations = {
-        'verse': section_lines['verse'] * time_per_line['verse'],
-        'chorus': section_lines['chorus'] * time_per_line['chorus'],
-        'bridge': section_lines['bridge'] * time_per_line['bridge']
-    }
+    # Calculate the expected duration of each section (including the last one)
+    section_durations = {}
+    for section_type in ['verse', 'chorus', 'bridge']:
+        lines_count = len(section_lines[section_type])
+        section_durations[section_type] = lines_count * time_per_line[section_type]
 
-    total_duration = sum(section_durations.values())
-    total_duration = max(60, total_duration)  # at least 60 seconds
+    # Calculate the total duration (with buffer time added)
+    total_duration = sum(duration for duration in section_durations.values())
+    total_duration = max(60, int(total_duration * 1.2))  # add a 20% buffer
 
-    # Token calculation (using more conservative values)
-    base_tokens = 3000  # base token count
-    tokens_per_line = 200  # tokens per line
+    # Token calculation (extra tokens for the final section)
+    base_tokens = 3000
+    tokens_per_line = 200
+    extra_tokens = 1000  # extra tokens for the final section
 
-    total_tokens = base_tokens + (total_lines * tokens_per_line)
+    total_tokens = base_tokens + (total_lines * tokens_per_line) + extra_tokens
 
-    # Calculate the segment count from the sections
+    # Calculate the segment count (an extra segment for the final section)
     if sections['chorus'] > 0:
-        num_segments = 3  # 3 segments when a chorus is present
+        num_segments = 4  # 4 segments when a chorus is present
     else:
-        num_segments = 2  # 2 segments when there is no chorus
+        num_segments = 3  # 3 segments when there is no chorus
 
-    # Token limit
-    max_tokens = min(8000, total_tokens)  # cap at 8000 tokens
+    # Token limit (raised cap)
+    max_tokens = min(12000, total_tokens)  # raised maximum token count
 
     return {
         'max_tokens': max_tokens,
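A worked example of the new budget arithmetic (the per-line second values for verse and chorus sit outside this hunk; the numbers below use only what is shown):

# 30 content lines with at least one [chorus] tag:
total_tokens = 3000 + 30 * 200 + 1000    # = 10000
max_tokens   = min(12000, total_tokens)  # = 10000 (the raised cap rarely binds)
num_segments = 4                         # chorus present
# duration: e.g. 48 s of raw section time becomes max(60, int(48 * 1.2)) = 60 s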
@@ -118,43 +157,15 @@ def calculate_generation_params(lyrics):
         'has_chorus': sections['chorus'] > 0
     }
 
-def get_audio_duration(file_path):
-    try:
-        import librosa
-        duration = librosa.get_duration(path=file_path)
-        return duration
-    except Exception as e:
-        logging.error(f"Failed to get audio duration: {e}")
-        return None
-
-# Language detection and model selection
 def detect_and_select_model(text):
-    if re.search(r'[\u3131-\u318E\uAC00-\uD7A3]', text):  # Korean
+    if re.search(r'[\u3131-\u318E\uAC00-\uD7A3]', text):
         return "m-a-p/YuE-s1-7B-anneal-jp-kr-cot"
-    elif re.search(r'[\u4e00-\u9fff]', text):  # Chinese
+    elif re.search(r'[\u4e00-\u9fff]', text):
         return "m-a-p/YuE-s1-7B-anneal-zh-cot"
-    elif re.search(r'[\u3040-\u309F\u30A0-\u30FF]', text):  # Japanese
+    elif re.search(r'[\u3040-\u309F\u30A0-\u30FF]', text):
         return "m-a-p/YuE-s1-7B-anneal-jp-kr-cot"
-    else:  # English and others
-        return "m-a-p/YuE-s1-7B-anneal-en-cot"
-
-# Optimize GPU settings
-def optimize_gpu_settings():
-    if torch.cuda.is_available():
-        torch.backends.cuda.matmul.allow_tf32 = True
-        torch.backends.cudnn.benchmark = True
-        torch.backends.cudnn.deterministic = False
-        torch.backends.cudnn.enabled = True
-
-        torch.cuda.empty_cache()
-        torch.cuda.set_device(0)
-
-        logging.info(f"Using GPU: {torch.cuda.get_device_name(0)}")
-        logging.info(f"Available GPU memory: {torch.cuda.get_device_properties(0).total_memory / 1024**3:.2f} GB")
     else:
-        logging.warning("GPU not available!")
+        return "m-a-p/YuE-s1-7B-anneal-en-cot"
 
 def install_flash_attn():
     try:
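Illustrative routing checks for detect_and_select_model (the Korean branch is tested first, so mixed Korean/Chinese text routes to the jp-kr model; Japanese text written only in kanji would fall through to the zh branch):

assert detect_and_select_model("안녕하세요") == "m-a-p/YuE-s1-7B-anneal-jp-kr-cot"  # Hangul
assert detect_and_select_model("你好") == "m-a-p/YuE-s1-7B-anneal-zh-cot"          # CJK ideographs
assert detect_and_select_model("こんにちは") == "m-a-p/YuE-s1-7B-anneal-jp-kr-cot"  # kana
assert detect_and_select_model("hello") == "m-a-p/YuE-s1-7B-anneal-en-cot"         # fallback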
@@ -176,17 +187,13 @@ def install_flash_attn():
     except ImportError:
         logging.info("Installing flash-attn...")
 
-        try:
-            subprocess.run(
-                ["pip", "install", "flash-attn", "--no-build-isolation"],
-                check=True,
-                capture_output=True
-            )
-            logging.info("flash-attn installed successfully!")
-            return True
-        except subprocess.CalledProcessError:
-            logging.warning("Failed to install flash-attn via pip, skipping...")
-            return False
+        subprocess.run(
+            ["pip", "install", "flash-attn", "--no-build-isolation"],
+            check=True,
+            capture_output=True
+        )
+        logging.info("flash-attn installed successfully!")
+        return True
 
     except Exception as e:
         logging.warning(f"Failed to install flash-attn: {e}")
@@ -194,19 +201,27 @@
 
 def initialize_system():
     optimize_gpu_settings()
-    has_flash_attn = install_flash_attn()
-
-    from huggingface_hub import snapshot_download
 
-    folder_path = './inference/xcodec_mini_infer'
-    os.makedirs(folder_path, exist_ok=True)
-    logging.info(f"Created folder at: {folder_path}")
-
-    snapshot_download(
-        repo_id="m-a-p/xcodec_mini_infer",
-        local_dir="./inference/xcodec_mini_infer",
-        resume_download=True
-    )
+    with ThreadPoolExecutor(max_workers=4) as executor:
+        futures = []
+
+        futures.append(executor.submit(install_flash_attn))
+
+        from huggingface_hub import snapshot_download
+
+        folder_path = './inference/xcodec_mini_infer'
+        os.makedirs(folder_path, exist_ok=True)
+        logging.info(f"Created folder at: {folder_path}")
+
+        futures.append(executor.submit(
+            snapshot_download,
+            repo_id="m-a-p/xcodec_mini_infer",
+            local_dir="./inference/xcodec_mini_infer",
+            resume_download=True
+        ))
+
+        for future in futures:
+            future.result()
 
     try:
         os.chdir("./inference")
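A self-contained sketch of the concurrency pattern introduced above: both setup tasks run in parallel, and future.result() blocks until each finishes, re-raising any exception from the worker thread:

from concurrent.futures import ThreadPoolExecutor

def task_a():
    return "flash-attn ready"

def task_b():
    return "weights ready"

with ThreadPoolExecutor(max_workers=4) as executor:
    futures = [executor.submit(task_a), executor.submit(task_b)]
    for future in futures:
        print(future.result())  # blocks; propagates worker exceptions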
@@ -215,7 +230,7 @@ def initialize_system():
         logging.error(f"Directory error: {e}")
         raise
 
-@lru_cache(maxsize=50)
+@lru_cache(maxsize=100)
 def get_cached_file_path(content_hash, prefix):
     return create_temp_file(content_hash, prefix)
 
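The cache now holds the 100 most recently used (content_hash, prefix) pairs; hit/miss behaviour can be inspected with the wrapper's built-in counters (illustrative, assuming create_temp_file as defined elsewhere in app-backup.py):

get_cached_file_path("abc123", "genre_")   # miss: creates a temp file
get_cached_file_path("abc123", "genre_")   # hit: returns the same path
print(get_cached_file_path.cache_info())   # CacheInfo(hits=1, misses=1, maxsize=100, currsize=1)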
@@ -247,84 +262,50 @@ def get_last_mp3_file(output_dir):
     mp3_files_with_path.sort(key=os.path.getmtime, reverse=True)
     return mp3_files_with_path[0]
 
-def optimize_model_selection(lyrics, genre):
-    model_path = detect_and_select_model(lyrics)
-    params = calculate_generation_params(lyrics)
-
-    # Adjust the settings depending on whether a chorus is present
-    has_chorus = params['sections']['chorus'] > 0
-
-    # Token calculation
-    tokens_per_segment = params['max_tokens'] // params['num_segments']
-
-    model_config = {
-        "m-a-p/YuE-s1-7B-anneal-en-cot": {
-            "max_tokens": params['max_tokens'],
-            "temperature": 0.8,
-            "batch_size": 8,
-            "num_segments": params['num_segments'],
-            "estimated_duration": params['estimated_duration']
-        },
-        "m-a-p/YuE-s1-7B-anneal-jp-kr-cot": {
-            "max_tokens": params['max_tokens'],
-            "temperature": 0.7,
-            "batch_size": 8,
-            "num_segments": params['num_segments'],
-            "estimated_duration": params['estimated_duration']
-        },
-        "m-a-p/YuE-s1-7B-anneal-zh-cot": {
-            "max_tokens": params['max_tokens'],
-            "temperature": 0.7,
-            "batch_size": 8,
-            "num_segments": params['num_segments'],
-            "estimated_duration": params['estimated_duration']
-        }
-    }
-
-    # Increase the token count when a chorus is present
-    if has_chorus:
-        for config in model_config.values():
-            config['max_tokens'] = int(config['max_tokens'] * 1.5)  # allocate 50% more tokens
-
-    return model_path, model_config[model_path], params
+def get_audio_duration(file_path):
+    try:
+        import librosa
+        duration = librosa.get_duration(path=file_path)
+        return duration
+    except Exception as e:
+        logging.error(f"Failed to get audio duration: {e}")
+        return None
 
 def infer(genre_txt_content, lyrics_txt_content, num_segments, max_new_tokens):
     genre_txt_path = None
     lyrics_txt_path = None
 
     try:
-        # Model selection and configuration
         model_path, config, params = optimize_model_selection(lyrics_txt_content, genre_txt_content)
         logging.info(f"Selected model: {model_path}")
         logging.info(f"Lyrics analysis: {params}")
 
-        # Check for chorus sections and log them
         has_chorus = params['sections']['chorus'] > 0
         estimated_duration = params.get('estimated_duration', 90)
 
-        # Adjust the token and segment counts
+        # Configure the segment and token counts
         if has_chorus:
-            actual_max_tokens = min(8000, int(config['max_tokens'] * 1.2))  # 20% more, capped at 8000
-            actual_num_segments = 3
+            actual_max_tokens = min(12000, int(config['max_tokens'] * 1.3))  # 30% more tokens
+            actual_num_segments = min(5, params['num_segments'] + 2)  # extra segments
         else:
-            actual_max_tokens = config['max_tokens']
-            actual_num_segments = 2
+            actual_max_tokens = min(10000, int(config['max_tokens'] * 1.2))
+            actual_num_segments = min(4, params['num_segments'] + 1)
 
         logging.info(f"Estimated duration: {estimated_duration} seconds")
         logging.info(f"Has chorus sections: {has_chorus}")
         logging.info(f"Using segments: {actual_num_segments}, tokens: {actual_max_tokens}")
 
-        # Create temporary files
         genre_txt_path = create_temp_file(genre_txt_content, prefix="genre_")
         lyrics_txt_path = create_temp_file(lyrics_txt_content, prefix="lyrics_")
 
         output_dir = "./output"
         os.makedirs(output_dir, exist_ok=True)
         empty_output_folder(output_dir)
-        # Build the base command
+
+        # Revised command - unsupported arguments removed
         command = [
             "python", "infer.py",
             "--stage1_model", model_path,
@@ -332,19 +313,13 @@ def infer(genre_txt_content, lyrics_txt_content, num_segments, max_new_tokens):
             "--genre_txt", genre_txt_path,
             "--lyrics_txt", lyrics_txt_path,
             "--run_n_segments", str(actual_num_segments),
-            "--stage2_batch_size", "4",  # reduced batch size
+            "--stage2_batch_size", "16",
             "--output_dir", output_dir,
             "--cuda_idx", "0",
-            "--max_new_tokens", str(actual_max_tokens)
+            "--max_new_tokens", str(actual_max_tokens),
+            "--disable_offload_model"  # added for GPU memory optimization
         ]
 
-        # GPU settings
-        if torch.cuda.is_available():
-            command.append("--disable_offload_model")
-
-        # Set the CUDA environment variables
         env = os.environ.copy()
         if torch.cuda.is_available():
             env.update({
@@ -352,7 +327,8 @@ def infer(genre_txt_content, lyrics_txt_content, num_segments, max_new_tokens):
                 "CUDA_HOME": "/usr/local/cuda",
                 "PATH": f"/usr/local/cuda/bin:{env.get('PATH', '')}",
                 "LD_LIBRARY_PATH": f"/usr/local/cuda/lib64:{env.get('LD_LIBRARY_PATH', '')}",
-                "PYTORCH_CUDA_ALLOC_CONF": f"max_split_size_mb:512"
+                "PYTORCH_CUDA_ALLOC_CONF": "max_split_size_mb:512",
+                "CUDA_LAUNCH_BLOCKING": "0"
             })
 
         # Handle the transformers cache migration
@@ -362,7 +338,6 @@ def infer(genre_txt_content, lyrics_txt_content, num_segments, max_new_tokens):
         except Exception as e:
             logging.warning(f"Cache migration warning (non-critical): {e}")
 
-        # Run the command
         process = subprocess.run(
             command,
             env=env,
@@ -371,7 +346,6 @@ def infer(genre_txt_content, lyrics_txt_content, num_segments, max_new_tokens):
             text=True
         )
 
-        # Log the execution result
         logging.info(f"Command output: {process.stdout}")
         if process.stderr:
             logging.error(f"Command error: {process.stderr}")
@@ -381,7 +355,6 @@ def infer(genre_txt_content, lyrics_txt_content, num_segments, max_new_tokens):
             logging.error(f"Command: {' '.join(command)}")
             raise RuntimeError(f"Inference failed: {process.stderr}")
 
-        # Process the results
         last_mp3 = get_last_mp3_file(output_dir)
         if last_mp3:
             try:
@@ -391,7 +364,6 @@ def infer(genre_txt_content, lyrics_txt_content, num_segments, max_new_tokens):
                 logging.info(f"Audio duration: {duration:.2f} seconds")
                 logging.info(f"Expected duration: {estimated_duration} seconds")
 
-                # Warn when the generated audio is too short
                 if duration < estimated_duration * 0.8:
                     logging.warning(f"Generated audio is shorter than expected: {duration:.2f}s < {estimated_duration:.2f}s")
             except Exception as e:
@@ -405,27 +377,55 @@ def infer(genre_txt_content, lyrics_txt_content, num_segments, max_new_tokens):
         logging.error(f"Inference error: {e}")
         raise
     finally:
-        # Clean up the temporary files
-        if genre_txt_path and os.path.exists(genre_txt_path):
-            try:
-                os.remove(genre_txt_path)
-                logging.debug(f"Removed temporary file: {genre_txt_path}")
-            except Exception as e:
-                logging.warning(f"Failed to remove temporary file {genre_txt_path}: {e}")
-
-        if lyrics_txt_path and os.path.exists(lyrics_txt_path):
-            try:
-                os.remove(lyrics_txt_path)
-                logging.debug(f"Removed temporary file: {lyrics_txt_path}")
-            except Exception as e:
-                logging.warning(f"Failed to remove temporary file {lyrics_txt_path}: {e}")
+        for path in [genre_txt_path, lyrics_txt_path]:
+            if path and os.path.exists(path):
+                try:
+                    os.remove(path)
+                    logging.debug(f"Removed temporary file: {path}")
+                except Exception as e:
+                    logging.warning(f"Failed to remove temporary file {path}: {e}")
+
+def optimize_model_selection(lyrics, genre):
+    model_path = detect_and_select_model(lyrics)
+    params = calculate_generation_params(lyrics)
+
+    has_chorus = params['sections']['chorus'] > 0
+    tokens_per_segment = params['max_tokens'] // params['num_segments']
+
+    model_config = {
+        "m-a-p/YuE-s1-7B-anneal-en-cot": {
+            "max_tokens": params['max_tokens'],
+            "temperature": 0.8,
+            "batch_size": 16,
+            "num_segments": params['num_segments'],
+            "estimated_duration": params['estimated_duration']
+        },
+        "m-a-p/YuE-s1-7B-anneal-jp-kr-cot": {
+            "max_tokens": params['max_tokens'],
+            "temperature": 0.7,
+            "batch_size": 16,
+            "num_segments": params['num_segments'],
+            "estimated_duration": params['estimated_duration']
+        },
+        "m-a-p/YuE-s1-7B-anneal-zh-cot": {
+            "max_tokens": params['max_tokens'],
+            "temperature": 0.7,
+            "batch_size": 16,
+            "num_segments": params['num_segments'],
+            "estimated_duration": params['estimated_duration']
+        }
+    }
+
+    if has_chorus:
+        for config in model_config.values():
+            config['max_tokens'] = int(config['max_tokens'] * 1.5)
+
+    return model_path, model_config[model_path], params
 
 def main():
-    # Gradio interface
     with gr.Blocks() as demo:
         with gr.Column():
             gr.Markdown("# Open SUNO: Full-Song Generation (Multi-Language Support)")
-
 
         with gr.Row():
             with gr.Column():
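An illustrative call of the relocated optimize_model_selection (note that the genre argument is accepted but unused in the body shown above):

model_path, config, params = optimize_model_selection("hello world", "pop")
# model_path           -> "m-a-p/YuE-s1-7B-anneal-en-cot"  (no CJK characters)
# config['batch_size'] -> 16
# config['max_tokens'] is scaled by 1.5 only when the lyrics contain a [chorus] tag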
@@ -462,10 +462,8 @@ def main():
                 submit_btn = gr.Button("Generate Music", variant="primary")
                 music_out = gr.Audio(label="Generated Audio")
 
-        # Multilingual examples
         gr.Examples(
             examples=[
-                # English example
                 [
                     "female blues airy vocal bright vocal piano sad romantic guitar jazz",
                     """[verse]
@@ -490,36 +488,27 @@ Guiding me back homeward, making my heart rejoice
 Don't let this moment fade, hold me close tonight
 With you here beside me, everything's alright
 Can't imagine life alone, don't want to let you go
-Stay with me forever, let our love just flow
-"""
+Stay with me forever, let our love just flow"""
                 ],
-                # Korean example
                 [
                     "K-pop bright energetic synth dance electronic",
                     """[verse]
 언젠가 마주한 눈빛 속에서
-우린 서로를 알아보았지
 
 [chorus]
 다시 한 번 내게 말해줘
-너의 진심을 숨기지 말아 줘
 
 [verse]
 어두운 밤을 지날 때마다
-너의 목소리를 떠올려
 
 [chorus]
 다시 한 번 내게 말해줘
-너의 진심을 숨기지 말아 줘
-
-
-"""
+"""
                 ]
             ],
             inputs=[genre_txt, lyrics_txt]
         )
 
-        # Initialize the system
         initialize_system()
 
     def update_info(lyrics):
@@ -533,9 +522,6 @@ Stay with me forever, let our love just flow
                 f"Verses: {sections['verse']}, Chorus: {sections['chorus']} (Expected full length including chorus)"
             )
 
-
-
-        # Event handler
         lyrics_txt.change(
             fn=update_info,
             inputs=[lyrics_txt],
@@ -558,5 +544,5 @@ if __name__ == "__main__":
         share=True,
         show_api=True,
         show_error=True,
-        max_threads=2
-    )
+        max_threads=8
+    )
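For context on the final change: max_threads in Gradio's launch() caps the number of worker threads serving requests, so raising it from 2 to 8 allows more concurrent sessions. A minimal usage sketch (illustrative, not from the commit):

import gradio as gr

with gr.Blocks() as demo:
    gr.Markdown("demo")

demo.queue().launch(share=True, show_error=True, max_threads=8)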