Update pico_model.py
pico_model.py CHANGED (+5 -57)
@@ -8,40 +8,6 @@ import torch.nn.functional as F
 from diffusers.utils.torch_utils import randn_tensor
 from diffusers import DDPMScheduler, UNet2DConditionModel
 
-from audioldm.audio.stft import TacotronSTFT
-from audioldm.variational_autoencoder.autoencoder import AutoencoderKL
-from audioldm.utils import default_audioldm_config, get_metadata
-
-
-
-def build_pretrained_models(name):
-    checkpoint = torch.load(get_metadata()[name]["path"], map_location="cpu")
-    scale_factor = checkpoint["state_dict"]["scale_factor"].item()
-
-    vae_state_dict = {k[18:]: v for k, v in checkpoint["state_dict"].items() if "first_stage_model." in k}
-
-    config = default_audioldm_config(name)
-    vae_config = config["model"]["params"]["first_stage_config"]["params"]
-    vae_config["scale_factor"] = scale_factor
-
-    vae = AutoencoderKL(**vae_config)
-    vae.load_state_dict(vae_state_dict)
-
-    fn_STFT = TacotronSTFT(
-        config["preprocessing"]["stft"]["filter_length"],
-        config["preprocessing"]["stft"]["hop_length"],
-        config["preprocessing"]["stft"]["win_length"],
-        config["preprocessing"]["mel"]["n_mel_channels"],
-        config["preprocessing"]["audio"]["sampling_rate"],
-        config["preprocessing"]["mel"]["mel_fmin"],
-        config["preprocessing"]["mel"]["mel_fmax"],
-    )
-
-    vae.eval()
-    fn_STFT.eval()
-
-    return vae, fn_STFT
-
 def _init_layer(layer):
     """Initialize a Linear or Convolutional layer. """
     nn.init.xavier_uniform_(layer.weight)
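An aside on the removed loader: the slice k[18:] relies on len("first_stage_model.") being exactly 18, so it strips that prefix from each checkpoint key before the weights are loaded into the standalone AutoencoderKL. A more self-documenting equivalent of that dictionary comprehension, as a sketch rather than code from this repo:

    # Sketch only: same effect as the removed k[18:] slice, without the magic number.
    PREFIX = "first_stage_model."  # len(PREFIX) == 18, hence k[18:] above
    vae_state_dict = {
        k[len(PREFIX):]: v
        for k, v in checkpoint["state_dict"].items()
        if k.startswith(PREFIX)  # the removed code tested containment with "in"
    }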
@@ -243,7 +209,7 @@ class ClapText_Onset_2_Audio_Diffusion(nn.Module):
 from sklearn.metrics.pairwise import cosine_similarity
 import laion_clap
 from laion_clap.clap_module.factory import load_state_dict as clap_load_state_dict
-
+from llm_preprocess import get_event
 class PicoDiffusion(ClapText_Onset_2_Audio_Diffusion):
     def __init__(self,
         scheduler_name,
@@ -260,31 +226,12 @@ class PicoDiffusion(ClapText_Onset_2_Audio_Diffusion):
         ckpt = clap_load_state_dict(freeze_text_encoder_ckpt, skip_params=True)
         del_parameter_key = ["text_branch.embeddings.position_ids"]
         ckpt = {f"freeze_text_encoder.model.{k}":v for k, v in ckpt.items() if k not in del_parameter_key}
-        diffusion_ckpt = torch.load(diffusion_pt)
+        diffusion_ckpt = torch.load(diffusion_pt, map_location=self.device)
         del diffusion_ckpt["class_emb.weight"]
         ckpt.update(diffusion_ckpt)
         self.load_state_dict(ckpt)
 
-        self.event_list = [
-            "burping_belching", # 0
-            "car_horn_honking", # 1
-            "cat_meowing", # 2
-            "cow_mooing", # 3
-            "dog_barking", # 4
-            "door_knocking", # 5
-            "door_slamming", # 6
-            "explosion", # 7
-            "gunshot", # 8
-            "sheep_goat_bleating", # 9
-            "sneeze", # 10
-            "spraying", # 11
-            "thump_thud", # 12
-            "train_horn", # 13
-            "tapping_clicking_clanking", # 14
-            "woman_laughing", # 15
-            "duck_quacking", # 16
-            "whistling", # 17
-        ]
+        self.event_list = get_event()
         self.events_emb = self.freeze_text_encoder.get_text_embedding(self.event_list, use_tensor=False)
 
 
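The added map_location argument is what lets the diffusion checkpoint load on a CPU-only host: torch.load otherwise restores tensors onto the device they were saved from, and a CUDA-saved checkpoint then fails to deserialize when no GPU is present. A minimal sketch of the behavior, with a placeholder checkpoint path (self.device is assumed to be set earlier in __init__; the diff does not show where):

    # Sketch: loading a CUDA-saved checkpoint safely on any host.
    # "diffusion.pt" is a placeholder path, not a file from this repo.
    import torch
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    state = torch.load("diffusion.pt", map_location=device)
    # Without map_location, a checkpoint saved from a CUDA process raises a
    # RuntimeError ("Attempting to deserialize object on a CUDA device ...")
    # when loaded on a CPU-only machine.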
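The companion change moves the hard-coded 18-label event list behind get_event() in llm_preprocess. That module is not part of this diff; judging only from the deleted lines, a minimal implementation consistent with it would be:

    # Sketch of llm_preprocess.get_event(), reconstructed from the deleted list;
    # the actual module is not shown in this diff.
    def get_event():
        # Order matters: rows of events_emb are matched back to labels by index.
        return [
            "burping_belching", "car_horn_honking", "cat_meowing", "cow_mooing",
            "dog_barking", "door_knocking", "door_slamming", "explosion",
            "gunshot", "sheep_goat_bleating", "sneeze", "spraying",
            "thump_thud", "train_horn", "tapping_clicking_clanking",
            "woman_laughing", "duck_quacking", "whistling",
        ]

Keeping the list in one module lets the preprocessing code and the model agree on a single label set instead of two copies drifting apart.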
@@ -300,10 +247,11 @@ class PicoDiffusion(ClapText_Onset_2_Audio_Diffusion):
         for event_timestamp in timestampCaption.split(' and '):
             # event_timestamp : event1__onset1-offset1_onset2-offset2
             (event, instance) = event_timestamp.split(' at ')
-
+
             # instance : onset1-offset1_onset2-offset2
             event_emb = self.freeze_text_encoder.get_text_embedding([event, ""], use_tensor=False)[0]
             event_id = np.argmax(cosine_similarity(event_emb.reshape(1, -1), self.events_emb))
+            events.append(self.event_list[event_id])
             for start_end in instance.split('_'):
                 (start, end) = start_end.split('-')
                 start, end = int(float(start)*250/10), int(float(end)*250/10)
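The parsing and the *250/10 conversion are easiest to check on a concrete caption: the code maps seconds to frame indices at 250 frames per 10 seconds, i.e. 25 frames per second. A worked example with an invented caption string:

    # Worked example of the caption parsing above; the caption is invented.
    caption = "dog_barking at 0.5-2.0_4.0-5.5 and cat_meowing at 7.0-9.0"
    for event_timestamp in caption.split(' and '):
        event, instance = event_timestamp.split(' at ')  # "event at onsets-offsets"
        for start_end in instance.split('_'):            # each onset-offset pair
            start, end = start_end.split('-')
            # 250 frames / 10 s = 25 frames per second
            start, end = int(float(start) * 250 / 10), int(float(end) * 250 / 10)
            print(event, start, end)
    # dog_barking 12 50
    # dog_barking 100 137
    # cat_meowing 175 225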