import torch
from torch.utils.data import Dataset
import numpy as np
from tqdm import tqdm
import random
import logging
from os.path import join, exists
from third_party.cgdetr.utils.basic_utils import load_jsonl, l2_normalize_np_array
from third_party.cgdetr.utils.tensor_utils import pad_sequences_1d
from third_party.cgdetr.cg_detr.span_utils import span_xx_to_cxw
# from torchtext import vocab
import torch.nn as nn

logger = logging.getLogger(__name__)


class StartEndDataset(Dataset):
    """Dataset yielding query/video features and moment/saliency labels.

    One line in the data loaded from ``data_path`` looks like:
    {
      "qid": 7803,
      "query": "Man in gray top walks from outside to inside.",
      "duration": 150,
      "vid": "RoripwjYFp8_360.0_510.0",
      "relevant_clip_ids": [13, 14, 15, 16, 17],
      "relevant_windows": [[26, 36]]
    }
    """

    Q_FEAT_TYPES = ["pooler_output", "last_hidden_state"]

    def __init__(self, dset_name, data_path, v_feat_dirs, q_feat_dir,
                 q_feat_type="last_hidden_state",
                 max_q_l=32, max_v_l=75, data_ratio=1.0, ctx_mode="video",
                 normalize_v=True, normalize_t=True, load_labels=True,
                 clip_len=2, max_windows=5, span_loss_type="l1",
                 txt_drop_ratio=0, dset_domain=None):
        """Store the configuration and load the annotation list.

        Args:
            dset_name: dataset identifier; 'hl' forces max_q_l=32,
                max_v_l=75 and clip_len=2.
            data_path: path of the jsonl annotation file.
            v_feat_dirs: one directory or a list of directories holding
                per-video feature files named ``{vid}.pt``.
            q_feat_dir: directory holding query features named
                ``qid{qid}.pt``.
            q_feat_type: one of ``Q_FEAT_TYPES``.
            max_q_l: maximum query length, -1 is replaced by 100.
            max_v_l: maximum video length, -1 is replaced by 100000000.
            data_ratio: fraction of the annotation list to keep.
            ctx_mode: string; containing "video" enables video features,
                containing "tef" enables temporal endpoint features.
            normalize_v: l2-normalize video features.
            normalize_t: l2-normalize query features.
            load_labels: stored on the instance, not read in this class.
            clip_len: clip length used to convert seconds to clip indices.
            max_windows: maximum number of windows to use as labels.
            span_loss_type: "l1" or "ce", see ``get_span_labels``.
            txt_drop_ratio: fraction of query rows zeroed out at random;
                must be 0 when ``data_path`` contains "val" or "test".
            dset_domain: accepted for interface compatibility, unused here.
        """
        self.dset_name = dset_name
        self.data_path = data_path
        self.data_ratio = data_ratio
        self.v_feat_dirs = v_feat_dirs \
            if isinstance(v_feat_dirs, list) else [v_feat_dirs]
        self.q_feat_dir = q_feat_dir
        self.q_feat_type = q_feat_type
        if max_v_l == -1:
            max_v_l = 100000000
        if max_q_l == -1:
            max_q_l = 100
        self.max_q_l = max_q_l
        self.max_v_l = max_v_l
        self.ctx_mode = ctx_mode
        self.use_tef = "tef" in ctx_mode
        self.use_video = "video" in ctx_mode
        self.normalize_t = normalize_t
        self.normalize_v = normalize_v
        self.load_labels = load_labels
        self.clip_len = clip_len
        self.max_windows = max_windows  # maximum number of windows to use as labels
        self.span_loss_type = span_loss_type
        self.txt_drop_ratio = txt_drop_ratio
        if "val" in data_path or "test" in data_path:
            assert txt_drop_ratio == 0

        if self.dset_name == 'hl':
            self.max_q_l = 32
            self.max_v_l = 75
            self.clip_len = 2

        # checks
        assert q_feat_type in self.Q_FEAT_TYPES

        # data
        self.data = self.load_data()

        # GloVe queries are selected when the first video feature directory
        # name contains 'vgg'; the GloVe path itself is disabled (see
        # get_query).
        self.use_glove = 'vgg' in self.v_feat_dirs[0]
        # if self.dset_name == 'charadesSTA' and self.use_glove:
        #     self.vocab = vocab.pretrained_aliases['glove.6B.300d']()
        #     self.vocab.itos.extend([''])
        #     self.vocab.stoi[''] = self.vocab.vectors.shape[0]
        #     self.vocab.vectors = torch.cat(
        #         (self.vocab.vectors, torch.zeros(1, self.vocab.dim)), dim=0)
        #     self.embedding = nn.Embedding.from_pretrained(self.vocab.vectors)

    def load_data(self):
        """Read the jsonl annotations, keeping the first ``data_ratio`` share."""
        datalist = load_jsonl(self.data_path)
        if self.data_ratio != 1:
            n_examples = int(len(datalist) * self.data_ratio)
            datalist = datalist[:n_examples]
            logger.info("Using {}% of the data: {} examples"
                        .format(self.data_ratio * 100, n_examples))
        return datalist

    def __len__(self):
        return len(self.data)

    def __getitem__(self, index):
        """Build the model inputs for one annotation.

        Returns:
            dict with ``meta`` (the raw annotation) and ``model_inputs``
            (features, optional labels, ``vid`` and ``qid``).
        """
        meta = self.data[index]

        model_inputs = dict()

        if self.use_glove:
            model_inputs["query_feat"] = self.get_query(meta["query"])
        else:
            model_inputs["query_feat"] = self._get_query_feat_by_qid(meta["qid"])  # (Dq, ) or (Lq, Dq)

        if self.use_video:
            model_inputs["video_feat"] = self._get_video_feat_by_vid(meta["vid"])  # (Lv, Dv)
            ctx_l = len(model_inputs["video_feat"])
        else:
            ctx_l = self.max_v_l

        if self.use_tef:
            tef_st = torch.arange(0, ctx_l, 1.0) / ctx_l
            tef_ed = tef_st + 1.0 / ctx_l
            tef = torch.stack([tef_st, tef_ed], dim=1)  # (Lv, 2)
            if self.use_video:
                model_inputs["video_feat"] = torch.cat(
                    [model_inputs["video_feat"], tef], dim=1)  # (Lv, Dv+2)
            else:
                model_inputs["video_feat"] = tef

        # Annotations without "relevant_windows" carry no labels.
        if "relevant_windows" in meta:
            model_inputs["span_labels"] = self.get_span_labels(meta["relevant_windows"], ctx_l)  # (#windows, 2)
            single_window = (
                self.dset_name in ['charadesSTA', 'tacos', 'activitynet']
                or "subs_train" in self.data_path
            )
            if single_window:
                # only one ground-truth window is used
                saliency = self.get_saliency_labels_sub_as_query(
                    meta["relevant_windows"][0], meta["duration"], ctx_l)
            else:
                saliency = self.get_saliency_labels_all(
                    meta["relevant_clip_ids"], meta["saliency_scores"], ctx_l)
            model_inputs["saliency_pos_labels"] = saliency[0]
            model_inputs["saliency_neg_labels"] = saliency[1]
            model_inputs["saliency_all_labels"] = saliency[2]

        # The previous test `'qvhighlight' or 'qvhl' in self.data_path` was
        # always truthy, so the field was copied for every dataset and raised
        # KeyError when absent. Copy it whenever the annotation provides it.
        if "relevant_clip_ids" in meta:
            model_inputs["relevant_clip_ids"] = meta["relevant_clip_ids"]
        model_inputs["vid"] = meta["vid"]
        model_inputs["qid"] = meta["qid"]
        return dict(meta=meta, model_inputs=model_inputs)

    # def get_query(self, query):
    #     word_inds = torch.LongTensor(
    #         [self.vocab.stoi.get(w.lower(), 400000) for w in query.split()])
    #     return self.embedding(word_inds)

    def get_query(self, query):
        """Placeholder for GloVe query embedding, which is disabled.

        Raises:
            NotImplementedError: always. Raising (rather than calling
                ``exit()``) lets the error surface from DataLoader workers.
        """
        raise NotImplementedError(
            "GloVe query features are not supported: the vocabulary/embedding "
            "setup is disabled in StartEndDataset.__init__")

    def get_saliency_labels_sub_as_query(self, gt_window, duration, ctx_l, max_n=2):
        """Sample saliency labels from a single ground-truth window.

        Args:
            gt_window: [st, ed] of the ground-truth window, in the same
                unit as ``duration``.
            duration: video duration.
            ctx_l: int, number of clips.
            max_n: int, number of clips sampled inside and outside the
                window.

        Returns:
            (pos_clip_indices, neg_clip_indices, score_array), where
            score_array is a (ctx_l,) array set to 1 inside the window.
        """
        clip_len = duration / ctx_l
        gt_st = int(gt_window[0] / clip_len)
        gt_ed = max(0, min(int(gt_window[1] / clip_len), ctx_l) - 1)
        if gt_st > gt_ed:
            gt_st = gt_ed

        if gt_st != gt_ed:
            # randomly pick clips among the ground-truth clip indices
            pos_clip_indices = random.sample(range(gt_st, gt_ed + 1), k=max_n)
        else:
            pos_clip_indices = [gt_st, gt_st]

        # clip indices outside the ground-truth window
        neg_pool = list(range(0, gt_st)) + list(range(gt_ed + 1, ctx_l))
        try:
            # randomly pick clips among the non-ground-truth clip indices
            neg_clip_indices = random.sample(neg_pool, k=max_n)
        except ValueError:
            # fewer than max_n clips outside the window: reuse the positives
            neg_clip_indices = pos_clip_indices

        # For charades_sta
        score_array = np.zeros(ctx_l)
        score_array[gt_st:gt_ed + 1] = 1

        return pos_clip_indices, neg_clip_indices, score_array

    def _pick_saliency_clips(self, rel_clip_ids, agg_scores, ctx_l, max_n,
                             add_easy_negative):
        """Pick hard (score-ranked) and easy (random) positive/negative clips."""
        sort_indices = np.argsort(agg_scores)  # increasing

        # indices in the whole video
        # the min(_, ctx_l-1) here is incorrect, but should not cause
        # much troubles since this should be rarely used.
        hard_pos = [min(rel_clip_ids[idx], ctx_l - 1) for idx in sort_indices[-max_n:]]
        hard_neg = [min(rel_clip_ids[idx], ctx_l - 1) for idx in sort_indices[:max_n]]
        easy_pos = []
        easy_neg = []
        if add_easy_negative:
            easy_neg_pool = list(set(range(ctx_l)) - set(rel_clip_ids))
            if len(easy_neg_pool) >= max_n:
                easy_pos = random.sample(rel_clip_ids, k=max_n)
                easy_neg = random.sample(easy_neg_pool, k=max_n)
            else:  # copy the hard ones
                easy_pos = hard_pos
                easy_neg = hard_neg

        return hard_pos + easy_pos, hard_neg + easy_neg

    def get_saliency_labels(self, rel_clip_ids, scores, ctx_l, max_n=1, add_easy_negative=True):
        """Sum the annotation scores per clip and pick positive/negative clips.

        The ``max_n`` highest-scoring relevant clips are hard positives and
        the ``max_n`` lowest-scoring ones are hard negatives.

        Args:
            rel_clip_ids: list(int), list of relevant clip ids
            scores: list([anno1_score, anno2_score, anno3_score]),
            ctx_l: int
            max_n: int, #clips to use as positive and negative, for easy and
                hard negative, respectively.
            add_easy_negative: bool, if True, sample easy negatives outside
                the relevant_clip_ids (and easy positives inside them).

        Returns:
            (pos_clip_indices, neg_clip_indices)
        """
        # indices inside rel_clip_ids
        scores = np.array(scores)  # (#rel_clips, 3)
        agg_scores = np.sum(scores, 1)  # (#rel_clips, )
        return self._pick_saliency_clips(
            rel_clip_ids, agg_scores, ctx_l, max_n, add_easy_negative)

    def get_saliency_labels_all(self, rel_clip_ids, scores, ctx_l, max_n=1, add_easy_negative=True):
        """Like ``get_saliency_labels`` but also return per-clip scores.

        Args:
            rel_clip_ids: list(int), list of relevant clip ids
            scores: list([anno1_score, anno2_score, anno3_score]),
            ctx_l: int
            max_n: int, #clips to use as positive and negative, for easy and
                hard negative, respectively.
            add_easy_negative: bool, if True, sample easy negatives outside
                the relevant_clip_ids (and easy positives inside them).

        Returns:
            (pos_clip_indices, neg_clip_indices, score_array), where
            score_array holds the summed score at each relevant clip id and
            0 elsewhere. Its length is ctx_l, or larger when a relevant clip
            id is >= ctx_l.
        """
        # indices inside rel_clip_ids
        scores = np.array(scores)  # (#rel_clips, 3)
        agg_scores = np.sum(scores, 1)  # (#rel_clips, )

        # Size the array once so that clip ids beyond ctx_l still fit.
        n_slots = max([ctx_l] + [clip_id + 1 for clip_id in rel_clip_ids])
        score_array = np.zeros(n_slots)
        for clip_id, agg_score in zip(rel_clip_ids, agg_scores):
            score_array[clip_id] = agg_score

        pos_clip_indices, neg_clip_indices = self._pick_saliency_clips(
            rel_clip_ids, agg_scores, ctx_l, max_n, add_easy_negative)
        return pos_clip_indices, neg_clip_indices, score_array

    def get_span_labels(self, windows, ctx_l):
        """Convert ground-truth windows into span labels.

        windows: list([st, ed]) in seconds. E.g. [[26, 36]], corresponding
            st_ed clip_indices [[13, 17]] (inclusive)
        Note a maximum of `self.max_windows` windows are used.

        Returns:
            For span_loss_type "l1": Tensor of shape (#windows, 2), each row
            is [center, width] normalized by video length.
            For span_loss_type "ce": LongTensor of shape (#windows, 2) with
            inclusive [st, ed] clip indices.

        Raises:
            NotImplementedError: for any other span_loss_type.
        """
        if len(windows) > self.max_windows:
            # Sample into a new list; shuffling in place would reorder the
            # annotation stored in self.data.
            windows = random.sample(windows, k=self.max_windows)
        if self.span_loss_type == "l1":
            windows = torch.Tensor(windows) / (ctx_l * self.clip_len)  # normalized windows in xx
            windows = span_xx_to_cxw(windows)  # normalized windows in cxw
        elif self.span_loss_type == "ce":
            windows = torch.Tensor([
                [int(w[0] / self.clip_len), min(int(w[1] / self.clip_len), ctx_l) - 1]
                for w in windows]).long()  # inclusive
        else:
            raise NotImplementedError
        return windows

    def _get_query_feat_by_qid(self, qid):
        """Load the query feature stored at ``{q_feat_dir}/qid{qid}.pt``.

        Truncates to ``max_q_l`` rows for "last_hidden_state" features, then
        optionally l2-normalizes and randomly zeroes rows.
        """
        q_feat_path = join(self.q_feat_dir, f"qid{qid}.pt")
        # NOTE(review): torch.load unpickles the file; only load trusted
        # feature files.
        q_feat = torch.load(q_feat_path).numpy().astype(np.float32)
        if self.q_feat_type == "last_hidden_state":
            q_feat = q_feat[:self.max_q_l]
        if self.normalize_t:
            q_feat = l2_normalize_np_array(q_feat)
        if self.txt_drop_ratio > 0:
            q_feat = self.random_drop_rows(q_feat)
        return torch.from_numpy(q_feat)  # (D, ) or (Lq, D)

    def random_drop_rows(self, embeddings):
        """Randomly set ``round(L * txt_drop_ratio)`` rows to zero, in place.

        Args:
            embeddings: np.ndarray (L, D)

        Returns:
            The same array object, after masking.
        """
        num_drop_rows = round(len(embeddings) * self.txt_drop_ratio)
        if num_drop_rows > 0:
            row_indices = np.random.choice(
                len(embeddings), size=num_drop_rows, replace=False)
            embeddings[row_indices] = 0
        return embeddings

    def _get_video_feat_by_vid(self, vid):
        """Load and concatenate the video features from every feature dir.

        Each ``{vid}.pt`` file holds either a tensor or a dict with the
        tensor under "features". Features are truncated to ``max_v_l`` rows,
        optionally l2-normalized, cut to the shortest length across
        directories and concatenated along the feature axis.
        """
        v_feat_list = []
        for _feat_dir in self.v_feat_dirs:
            _feat_path = join(_feat_dir, f"{vid}.pt")
            # Load once and branch on the container type; a bare except here
            # used to hide I/O errors and load the file a second time.
            # NOTE(review): torch.load unpickles the file; only load trusted
            # feature files.
            loaded = torch.load(_feat_path)
            if isinstance(loaded, dict):
                loaded = loaded["features"]
            _feat = loaded[:self.max_v_l].numpy().astype(np.float32)
            if self.normalize_v:
                _feat = l2_normalize_np_array(_feat)
            v_feat_list.append(_feat)
        # some features are slightly longer than the others
        min_len = min(len(e) for e in v_feat_list)
        v_feat_list = [e[:min_len] for e in v_feat_list]
        v_feat = np.concatenate(v_feat_list, axis=1)
        return torch.from_numpy(v_feat)  # (Lv, D)


def start_end_collate(batch):
    """Collate a list of ``StartEndDataset`` items into one batch.

    Returns:
        (batch_meta, batched_data): the list of raw annotations, and a dict
        keyed like ``model_inputs`` of the first item. Span labels become a
        list of ``dict(spans=...)``, saliency pos/neg labels a LongTensor,
        ``saliency_all_labels`` a padded float tensor, ``qid``/``vid`` plain
        lists; everything else is the result of ``pad_sequences_1d``.
    """
    batch_meta = [e["meta"] for e in batch]  # kept as-is, not collated

    batched_data = dict()
    for k in batch[0]["model_inputs"].keys():
        values = [e["model_inputs"][k] for e in batch]
        if k == "span_labels":
            batched_data[k] = [dict(spans=v) for v in values]
        elif k in ("saliency_pos_labels", "saliency_neg_labels"):
            batched_data[k] = torch.LongTensor(values)
        elif k == "saliency_all_labels":
            pad_data, _ = pad_sequences_1d(values, dtype=np.float32, fixed_length=None)
            batched_data[k] = torch.tensor(pad_data, dtype=torch.float32)
        elif k in ("qid", "vid"):
            batched_data[k] = values
        else:
            batched_data[k] = pad_sequences_1d(
                values, dtype=torch.float32, fixed_length=None)
    return batch_meta, batched_data


def prepare_batch_inputs(batched_model_inputs):
    """Split a collated batch into model inputs and training targets.

    Returns:
        (model_inputs, targets): ``targets`` is None when the batch carries
        no span or saliency labels.
    """
    query_feat = batched_model_inputs["query_feat"]
    video_feat = batched_model_inputs["video_feat"]
    model_inputs = dict(
        src_txt=query_feat[0],
        src_txt_mask=query_feat[1],
        src_vid=video_feat[0],
        src_vid_mask=video_feat[1],
        vid=batched_model_inputs["vid"],
        qid=batched_model_inputs["qid"],
    )
    targets = {}

    if "span_labels" in batched_model_inputs:
        targets["span_labels"] = [
            dict(spans=e["spans"])
            for e in batched_model_inputs["span_labels"]
        ]
    if "saliency_pos_labels" in batched_model_inputs:
        for name in ["saliency_pos_labels", "saliency_neg_labels"]:
            targets[name] = batched_model_inputs[name]

    if "saliency_all_labels" in batched_model_inputs:
        targets["saliency_all_labels"] = batched_model_inputs["saliency_all_labels"]
        targets["relevant_clips"] = batched_model_inputs["saliency_all_labels"]
    return model_inputs, (targets or None)