# Make-An-Audio-3: scripts/txt2audio_for_2cap_flow.py
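"""Text-to-audio sampling script for the two-caption ('ori_caption' / 'struct_caption')
flow variant: loads a model from an OmegaConf config plus checkpoint, samples latents
(with optional classifier-free guidance), decodes them into mel spectrograms, and
renders waveforms with a BigVGAN vocoder. It can either sweep the test dataset defined
in the config or generate audio for a single --prompt."""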
import argparse, os, sys, glob
import pathlib
directory = pathlib.Path(os.getcwd())
print(directory)
sys.path.append(str(directory))
import torch
import numpy as np
from omegaconf import OmegaConf
from ldm.util import instantiate_from_config
from ldm.models.diffusion.ddim import DDIMSampler
from ldm.models.diffusion.plms import PLMSSampler
import pandas as pd
from tqdm import tqdm
import preprocess.n2s_by_openai as n2s
from vocoder.bigvgan.models import VocoderBigVGAN
import soundfile
import torchaudio, math
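# Build the model described by `config.model` and, if given, load weights from a
# Lightning-style checkpoint (its "state_dict"), reporting missing/unexpected keys.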
def load_model_from_config(config, ckpt = None, verbose=True):
model = instantiate_from_config(config.model)
if ckpt:
print(f"Loading model from {ckpt}")
pl_sd = torch.load(ckpt, map_location="cpu")
sd = pl_sd["state_dict"]
m, u = model.load_state_dict(sd, strict=False)
if len(m) > 0 and verbose:
print("missing keys:")
print(m)
if len(u) > 0 and verbose:
print("unexpected keys:")
print(u)
else:
print(f"Note chat no ckpt is loaded !!!")
    if torch.cuda.is_available():
        model.cuda()
model.eval()
return model
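# Command-line options: --base is the model config YAML, --resume the model checkpoint,
# --vocoder-ckpt the BigVGAN vocoder directory; --H/--W fix the latent spectrogram shape
# and --scale controls the classifier-free guidance strength.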
def parse_args():
parser = argparse.ArgumentParser()
parser.add_argument(
"--prompt",
type=str,
# default="A large truck driving by as an emergency siren wails and truck horn honks",
default='This instrumental song features a relaxing melody with a country feel, accompanied by a guitar, piano, simple percussion, and bass in a slow tempo',
help="the prompt to generate"
)
parser.add_argument(
"--sample_rate",
type=int,
default="16000",
help="sample rate of wav"
)
parser.add_argument(
"--test-dataset",
default="none",
help="test which dataset: testset"
)
parser.add_argument(
"--outdir",
type=str,
nargs="?",
help="dir to write results to",
default="outputs/txt2audio-samples"
)
parser.add_argument(
"--ddim_steps",
type=int,
default=25,
help="number of ddim sampling steps",
)
parser.add_argument(
"--n_iter",
type=int,
default=1,
help="sample this often",
)
parser.add_argument(
"--H",
type=int,
default=20, # keep fix
help="latent height, in pixel space",
)
parser.add_argument(
"--W",
type=int,
default=312, # keep fix
help="latent width, in pixel space",
)
parser.add_argument(
"--n_samples",
type=int,
default=1,
help="how many samples to produce for the given prompt",
)
parser.add_argument(
"--scale",
type=float,
default=5.0, # if it's 1, only condition is taken into consideration
help="unconditional guidance scale: eps = eps(x, empty) + scale * (eps(x, cond) - eps(x, empty))",
)
parser.add_argument(
"-r",
"--resume",
type=str,
const=True,
default="",
nargs="?",
help="resume from logdir or checkpoint in logdir",
)
parser.add_argument(
"-b",
"--base",
type=str,
help="paths to base configs. Loaded from left-to-right. "
"Parameters can be overwritten or added with command-line options of the form `--key value`.",
default="",
)
parser.add_argument(
"--vocoder-ckpt",
type=str,
help="paths to vocoder checkpoint",
default='vocoder/logs/audioset',
)
return parser.parse_args()
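# Bundles a loaded model, output directory, and (optionally) a vocoder, and writes the
# generated samples to disk, returning per-sample metadata dicts for the result CSV.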
class GenSamples:
    def __init__(self, opt, model, outpath, config, vocoder=None, save_mel=True, save_wav=True) -> None:
self.opt = opt
self.model = model
self.outpath = outpath
if save_wav:
assert vocoder is not None
self.vocoder = vocoder
self.save_mel = save_mel
self.save_wav = save_wav
self.channel_dim = self.model.channels
self.config = config
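    # Generate `n_iter` samples for one prompt dict, decode the latents into mel
    # spectrograms, and optionally save them as .npy files and/or vocoded .wav files.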
    def gen_test_sample(self, prompt, mel_name=None, wav_name=None, gt=None, video=None):  # prompt is {'ori_caption': 'xxx', 'struct_caption': 'xxx'}
uc = None
record_dicts = []
        if self.opt.scale != 1.0:
            try:  # audiocaps: the conditioner takes both caption fields
                uc = self.model.get_learned_conditioning({'ori_caption': "", 'struct_caption': ""})
            except Exception:  # audioset: the conditioner takes a plain caption
                uc = self.model.get_learned_conditioning(prompt['ori_caption'])
        for n in range(self.opt.n_iter):  # trange(self.opt.n_iter, desc="Sampling")
            try:  # audiocaps: the conditioner takes both caption fields
                c = self.model.get_learned_conditioning(prompt)  # shape: [1, 77, 1280], i.e. per-token embeddings, not yet pooled into a sentence embedding
            except Exception:  # audioset: the conditioner takes a plain caption
                c = self.model.get_learned_conditioning(prompt['ori_caption'])
if self.channel_dim>0:
shape = [self.channel_dim, self.opt.H, self.opt.W] # (z_dim, 80//2^x, 848//2^x)
else:
shape = [1, self.opt.H, self.opt.W]
x0 = torch.randn(shape, device=self.model.device)
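            # Classifier-free guidance mixes conditional and unconditional predictions:
            # eps = eps(x, empty) + scale * (eps(x, cond) - eps(x, empty)); scale == 1 disables it.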
if self.opt.scale == 1: # w/o cfg
sample, _ = self.model.sample(c, 1, timesteps=self.opt.ddim_steps, x_latent=x0)
else: # cfg
sample, _ = self.model.sample_cfg(c, self.opt.scale, uc, 1, timesteps=self.opt.ddim_steps, x_latent=x0)
x_samples_ddim = self.model.decode_first_stage(sample)
for idx,spec in enumerate(x_samples_ddim):
spec = spec.squeeze(0).cpu().numpy()
record_dict = {'caption':prompt['ori_caption'][0]}
if self.save_mel:
mel_path = os.path.join(self.outpath,mel_name+f'_{idx}.npy')
np.save(mel_path,spec)
record_dict['mel_path'] = mel_path
if self.save_wav:
wav = self.vocoder.vocode(spec)
wav_path = os.path.join(self.outpath,wav_name+f'_{idx}.wav')
soundfile.write(wav_path, wav, self.opt.sample_rate)
record_dict['audio_path'] = wav_path
record_dicts.append(record_dict)
# if gt != None:
# wav_gt = self.vocoder.vocode(gt)
# wav_path = os.path.join(self.outpath, wav_name + f'_gt.wav')
# soundfile.write(wav_path, wav_gt, 16000)
return record_dicts
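# Entry point: build the model and vocoder from the CLI options, then either sweep the
# test dataset from the config (writing a tab-separated result.csv) or generate audio for
# a single prompt whose structured caption is produced via preprocess.n2s_by_openai.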
def main():
opt = parse_args()
# torch.manual_seed(55)
config = OmegaConf.load(opt.base)
model = load_model_from_config(config, opt.resume)
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
model = model.to(device)
os.makedirs(opt.outdir, exist_ok=True)
    vocoder = VocoderBigVGAN(opt.vocoder_ckpt, device)
    generator = GenSamples(opt, model, opt.outdir, config, vocoder, save_mel=False, save_wav=True)
csv_dicts = []
with torch.no_grad():
with model.ema_scope():
if opt.test_dataset != 'none':
if opt.test_dataset == 'testset':
test_dataset = instantiate_from_config(config['test_dataset'])
video = None
else:
raise NotImplementedError
print(f"Dataset: {type(test_dataset)} LEN: {len(test_dataset)}")
temp_n = 0
int_s = 0
for item in tqdm(test_dataset):
int_s += 1
if int_s < 2:
continue
# int_s += 1
prompt,f_name, gt = item['caption'],item['f_name'],item['image']
vname_num_split_index = f_name.rfind('_')# file_names[b]:video_name+'_'+num
v_n,num = f_name[:vname_num_split_index],f_name[vname_num_split_index+1:]
mel_name = f'{v_n}_sample_{num}'
wav_name = f'{v_n}_sample_{num}'
# write_gt_wav(v_n,opt.test_dataset2,opt.outdir,opt.sample_rate)
csv_dicts.extend(generator.gen_test_sample(prompt, mel_name=mel_name ,wav_name=wav_name, gt=gt, video=video))
if temp_n > 1:
break
temp_n += 1
df = pd.DataFrame.from_dict(csv_dicts)
df.to_csv(os.path.join(opt.outdir,'result.csv'),sep='\t',index=False)
else:
ori_caption = opt.prompt
struct_caption = n2s.get_struct(ori_caption)
# struct_caption = f'<{ori_caption}& all>'
print(f"The structed caption by Chatgpt is : {struct_caption}")
wav_name = f'{ori_caption.strip().replace(" ", "-")}'
prompt = {'ori_caption':[ori_caption],'struct_caption':[struct_caption]}
generator.gen_test_sample(prompt, wav_name=wav_name)
print(f"Your samples are ready and waiting four you here: \n{opt.outdir} \nEnjoy.")
if __name__ == "__main__":
main()
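# Example invocation (a sketch; the config/checkpoint paths below are illustrative and
# should point at your own model config, checkpoint, and BigVGAN vocoder directory):
#   python scripts/txt2audio_for_2cap_flow.py \
#       --base configs/txt2audio_cfm.yaml \
#       --resume path/to/model.ckpt \
#       --vocoder-ckpt vocoder/logs/audioset \
#       --prompt "A dog barks while birds chirp in the background" \
#       --outdir outputs/txt2audio-samples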