ZeyuXie committed
Commit ae95272 · verified · 1 Parent(s): ef76a0d

Update pico_model.py

Files changed (1):
  1. pico_model.py +5 -57
pico_model.py CHANGED
@@ -8,40 +8,6 @@ import torch.nn.functional as F
 from diffusers.utils.torch_utils import randn_tensor
 from diffusers import DDPMScheduler, UNet2DConditionModel
 
-from audioldm.audio.stft import TacotronSTFT
-from audioldm.variational_autoencoder.autoencoder import AutoencoderKL
-from audioldm.utils import default_audioldm_config, get_metadata
-
-
-
-def build_pretrained_models(name):
-    checkpoint = torch.load(get_metadata()[name]["path"], map_location="cpu")
-    scale_factor = checkpoint["state_dict"]["scale_factor"].item()
-
-    vae_state_dict = {k[18:]: v for k, v in checkpoint["state_dict"].items() if "first_stage_model." in k}
-
-    config = default_audioldm_config(name)
-    vae_config = config["model"]["params"]["first_stage_config"]["params"]
-    vae_config["scale_factor"] = scale_factor
-
-    vae = AutoencoderKL(**vae_config)
-    vae.load_state_dict(vae_state_dict)
-
-    fn_STFT = TacotronSTFT(
-        config["preprocessing"]["stft"]["filter_length"],
-        config["preprocessing"]["stft"]["hop_length"],
-        config["preprocessing"]["stft"]["win_length"],
-        config["preprocessing"]["mel"]["n_mel_channels"],
-        config["preprocessing"]["audio"]["sampling_rate"],
-        config["preprocessing"]["mel"]["mel_fmin"],
-        config["preprocessing"]["mel"]["mel_fmax"],
-    )
-
-    vae.eval()
-    fn_STFT.eval()
-
-    return vae, fn_STFT
-
 def _init_layer(layer):
     """Initialize a Linear or Convolutional layer."""
     nn.init.xavier_uniform_(layer.weight)
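In the removed loader, `k[18:]` silently depends on `len("first_stage_model.") == 18`: it strips that prefix from checkpoint keys so they line up with a standalone `AutoencoderKL`. A minimal sketch of the same remapping idiom, using `startswith` and an explicit prefix length instead of a magic number (the tensors below are placeholders, not from any real checkpoint):

    import torch

    # Remap "first_stage_model.*" keys onto a bare submodule's state dict.
    # len("first_stage_model.") == 18, which is where k[18:] comes from.
    ckpt = {
        "first_stage_model.encoder.weight": torch.zeros(2, 2),
        "cond_stage_model.weight": torch.ones(2),
    }
    prefix = "first_stage_model."
    vae_state_dict = {k[len(prefix):]: v for k, v in ckpt.items() if k.startswith(prefix)}
    assert list(vae_state_dict) == ["encoder.weight"]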
@@ -243,7 +209,7 @@ class ClapText_Onset_2_Audio_Diffusion(nn.Module):
 from sklearn.metrics.pairwise import cosine_similarity
 import laion_clap
 from laion_clap.clap_module.factory import load_state_dict as clap_load_state_dict
-
+from llm_preprocess import get_event
 class PicoDiffusion(ClapText_Onset_2_Audio_Diffusion):
     def __init__(self,
                  scheduler_name,
@@ -260,31 +226,12 @@ class PicoDiffusion(ClapText_Onset_2_Audio_Diffusion):
         ckpt = clap_load_state_dict(freeze_text_encoder_ckpt, skip_params=True)
         del_parameter_key = ["text_branch.embeddings.position_ids"]
         ckpt = {f"freeze_text_encoder.model.{k}": v for k, v in ckpt.items() if k not in del_parameter_key}
-        diffusion_ckpt = torch.load(diffusion_pt)
+        diffusion_ckpt = torch.load(diffusion_pt, map_location=self.device)
         del diffusion_ckpt["class_emb.weight"]
         ckpt.update(diffusion_ckpt)
         self.load_state_dict(ckpt)
 
-        self.event_list = [
-            "burping_belching",           # 0
-            "car_horn_honking",
-            "cat_meowing",
-            "cow_mooing",
-            "dog_barking",
-            "door_knocking",
-            "door_slamming",
-            "explosion",
-            "gunshot",                    # 8
-            "sheep_goat_bleating",
-            "sneeze",
-            "spraying",
-            "thump_thud",
-            "train_horn",
-            "tapping_clicking_clanking",
-            "woman_laughing",
-            "duck_quacking",              # 16
-            "whistling",
-        ]
+        self.event_list = get_event()
         self.events_emb = self.freeze_text_encoder.get_text_embedding(self.event_list, use_tensor=False)
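Two changes land in the hunk above: `torch.load` now pins the diffusion checkpoint to `self.device`, so weights saved from a GPU process also load cleanly elsewhere, and the eighteen hardcoded event labels move behind `llm_preprocess.get_event()`. The call site only requires that `get_event` return a list of label strings; a minimal sketch of a compatible implementation, assuming it simply returns the list this commit removes (the body below is an assumption, only the function name and call come from the diff):

    # Assumed shape of llm_preprocess.get_event(): any list of label strings
    # works, since it is fed straight to CLAP's get_text_embedding().
    def get_event():
        return [
            "burping_belching", "car_horn_honking", "cat_meowing", "cow_mooing",
            "dog_barking", "door_knocking", "door_slamming", "explosion",
            "gunshot", "sheep_goat_bleating", "sneeze", "spraying",
            "thump_thud", "train_horn", "tapping_clicking_clanking",
            "woman_laughing", "duck_quacking", "whistling",
        ]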
 
@@ -300,10 +247,11 @@ class PicoDiffusion(ClapText_Onset_2_Audio_Diffusion):
         for event_timestamp in timestampCaption.split(' and '):
             # event_timestamp : event1__onset1-offset1_onset2-offset2
             (event, instance) = event_timestamp.split(' at ')
-            events.append(event)
+
             # instance : onset1-offset1_onset2-offset2
             event_emb = self.freeze_text_encoder.get_text_embedding([event, ""], use_tensor=False)[0]
             event_id = np.argmax(cosine_similarity(event_emb.reshape(1, -1), self.events_emb))
+            events.append(self.event_list[event_id])
             for start_end in instance.split('_'):
                 (start, end) = start_end.split('-')
                 start, end = int(float(start)*250/10), int(float(end)*250/10)
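The functional fix above moves `events.append` below the CLAP lookup: instead of the free-text `event` parsed out of the caption, the canonical label `self.event_list[event_id]` is stored, where `event_id` is the cosine-similarity argmax between the phrase's CLAP embedding and the precomputed label embeddings. The `*250/10` arithmetic maps onset/offset seconds onto a 250-frame grid for a 10-second clip, i.e. 25 frames per second, so an onset of 2.4 s lands on frame 60. A self-contained sketch of the matching step with stand-in embeddings (the real ones come from `freeze_text_encoder`):

    import numpy as np
    from sklearn.metrics.pairwise import cosine_similarity

    # Stand-in unit embeddings, one row per canonical label; in the model
    # these come from CLAP's text encoder.
    event_list = ["dog_barking", "cat_meowing", "gunshot"]
    events_emb = np.eye(3)

    event_emb = np.array([0.1, 0.9, 0.2])   # embedding of the parsed phrase
    event_id = np.argmax(cosine_similarity(event_emb.reshape(1, -1), events_emb))
    print(event_list[event_id])             # -> "cat_meowing"

    # Seconds-to-frames conversion used above: 250 frames per 10 s clip.
    start, end = int(2.4 * 250 / 10), int(5.0 * 250 / 10)   # -> (60, 125)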
 