import torch
from torch.utils.data import Dataset
import numpy as np
from tqdm import tqdm
import random
import logging
from os.path import join, exists

from third_party.cgdetr.utils.basic_utils import load_jsonl, l2_normalize_np_array
from third_party.cgdetr.utils.tensor_utils import pad_sequences_1d
from third_party.cgdetr.cg_detr.span_utils import span_xx_to_cxw

import torch.nn as nn

logger = logging.getLogger(__name__)


class StartEndDataset(Dataset):
    """One line in the data loaded from data_path:
    {
      "qid": 7803,
      "query": "Man in gray top walks from outside to inside.",
      "duration": 150,
      "vid": "RoripwjYFp8_360.0_510.0",
      "relevant_clip_ids": [13, 14, 15, 16, 17],
      "relevant_windows": [[26, 36]]
    }
    """
    Q_FEAT_TYPES = ["pooler_output", "last_hidden_state"]

    def __init__(self, dset_name, data_path, v_feat_dirs, q_feat_dir,
                 q_feat_type="last_hidden_state",
                 max_q_l=32, max_v_l=75, data_ratio=1.0, ctx_mode="video",
                 normalize_v=True, normalize_t=True, load_labels=True,
                 clip_len=2, max_windows=5, span_loss_type="l1", txt_drop_ratio=0,
                 dset_domain=None):
        self.dset_name = dset_name
        self.data_path = data_path
        self.data_ratio = data_ratio
        self.v_feat_dirs = v_feat_dirs \
            if isinstance(v_feat_dirs, list) else [v_feat_dirs]
        self.q_feat_dir = q_feat_dir
        self.q_feat_type = q_feat_type
        if max_v_l == -1:
            max_v_l = 100000000
        if max_q_l == -1:
            max_q_l = 100
        self.max_q_l = max_q_l
        self.max_v_l = max_v_l
        self.ctx_mode = ctx_mode
        self.use_tef = "tef" in ctx_mode
        self.use_video = "video" in ctx_mode
        self.normalize_t = normalize_t
        self.normalize_v = normalize_v
        self.load_labels = load_labels
        self.clip_len = clip_len
        self.max_windows = max_windows  # maximum number of windows to use as labels
        self.span_loss_type = span_loss_type
        self.txt_drop_ratio = txt_drop_ratio
        if "val" in data_path or "test" in data_path:
            # no text dropout at evaluation time
            assert txt_drop_ratio == 0

        # QVHighlights ('hl') uses fixed query/video lengths and 2s clips.
        if self.dset_name == 'hl':
            self.max_q_l = 32
            self.max_v_l = 75
            self.clip_len = 2

        assert q_feat_type in self.Q_FEAT_TYPES

        self.data = self.load_data()

        # GloVe query features are used together with VGG video features.
        self.use_glove = 'vgg' in self.v_feat_dirs[0]

    def load_data(self):
        datalist = load_jsonl(self.data_path)
        if self.data_ratio != 1:
            n_examples = int(len(datalist) * self.data_ratio)
            datalist = datalist[:n_examples]
            logger.info("Using {}% of the data: {} examples"
                        .format(self.data_ratio * 100, n_examples))
        return datalist

    def __len__(self):
        return len(self.data)

    def __getitem__(self, index):
        meta = self.data[index]

        model_inputs = dict()

        if self.use_glove:
            model_inputs["query_feat"] = self.get_query(meta["query"])
        else:
            model_inputs["query_feat"] = self._get_query_feat_by_qid(meta["qid"])

        if self.use_video:
            model_inputs["video_feat"] = self._get_video_feat_by_vid(meta["vid"])
            ctx_l = len(model_inputs["video_feat"])
        else:
            ctx_l = self.max_v_l

        if self.use_tef:
            # Temporal endpoint features: each clip is represented by its
            # normalized [start, end] position within the video.
            tef_st = torch.arange(0, ctx_l, 1.0) / ctx_l
            tef_ed = tef_st + 1.0 / ctx_l
            tef = torch.stack([tef_st, tef_ed], dim=1)  # (Lv, 2)
            if self.use_video:
                model_inputs["video_feat"] = torch.cat(
                    [model_inputs["video_feat"], tef], dim=1)  # (Lv, Dv+2)
            else:
                model_inputs["video_feat"] = tef

        if "relevant_windows" in meta:
            model_inputs["span_labels"] = self.get_span_labels(meta["relevant_windows"], ctx_l)  # (#windows, 2)
            if self.dset_name in ['charadesSTA', 'tacos', 'activitynet']:
                model_inputs["saliency_pos_labels"], model_inputs["saliency_neg_labels"], model_inputs["saliency_all_labels"] = \
                    self.get_saliency_labels_sub_as_query(meta["relevant_windows"][0], meta["duration"], ctx_l)
            elif "subs_train" not in self.data_path:
                model_inputs["saliency_pos_labels"], model_inputs["saliency_neg_labels"], model_inputs["saliency_all_labels"] = \
                    self.get_saliency_labels_all(meta["relevant_clip_ids"], meta["saliency_scores"], ctx_l)
            else:
                model_inputs["saliency_pos_labels"], model_inputs["saliency_neg_labels"], model_inputs["saliency_all_labels"] = \
                    self.get_saliency_labels_sub_as_query(meta["relevant_windows"][0], meta["duration"], ctx_l)

        # `relevant_clip_ids` only exists in QVHighlights-style annotations;
        # `vid`/`qid` are always passed through for the collate function.
        if 'qvhighlight' in self.data_path or 'qvhl' in self.data_path:
            model_inputs["relevant_clip_ids"] = meta["relevant_clip_ids"]
        model_inputs["vid"] = meta["vid"]
        model_inputs["qid"] = meta["qid"]
        return dict(meta=meta, model_inputs=model_inputs)

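    # Illustrative summary (not from the original code): with the default 'hl'
    # settings (ctx_mode="video_tef", q_feat_type="last_hidden_state"), one item
    # looks roughly like
    #   {"meta": <raw jsonl entry>,
    #    "model_inputs": {
    #        "query_feat": FloatTensor (Lq, D_txt),
    #        "video_feat": FloatTensor (Lv, D_vid + 2),   # +2 from the TEF columns
    #        "span_labels": FloatTensor (#windows, 2),    # normalized [center, width]
    #        "saliency_pos_labels"/"saliency_neg_labels": two clip indices each,
    #        "saliency_all_labels": np.ndarray of per-clip scores,
    #        "vid": str, "qid": int}}
    # The exact keys depend on the dataset branch taken above.
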
    def get_query(self, query):
        # GloVe query features (the `use_glove` path) are not supported here.
        raise NotImplementedError(
            "GloVe query features are not supported; use precomputed text features instead.")

    def get_saliency_labels_sub_as_query(self, gt_window, duration, ctx_l, max_n=2):
        clip_len = duration / ctx_l
        gt_st = int(gt_window[0] / clip_len)
        gt_ed = max(0, min(int(gt_window[1] / clip_len), ctx_l) - 1)
        if gt_st > gt_ed:
            gt_st = gt_ed

        if gt_st != gt_ed:
            pos_clip_indices = random.sample(range(gt_st, gt_ed + 1), k=max_n)
        else:
            if self.dset_name == 'nlq':
                pos_clip_indices = [gt_st] * 2
            else:
                pos_clip_indices = [gt_st, gt_st]

        neg_pool = list(range(0, gt_st)) + list(range(gt_ed + 1, ctx_l))
        try:
            neg_clip_indices = random.sample(neg_pool, k=max_n)
        except ValueError:
            # Not enough clips outside the GT window to sample negatives from;
            # fall back to the positive indices.
            neg_clip_indices = pos_clip_indices

        # Binary saliency scores: 1 inside the GT window, 0 elsewhere.
        score_array = np.zeros(ctx_l)
        score_array[gt_st:gt_ed + 1] = 1

        return pos_clip_indices, neg_clip_indices, score_array

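    # Worked example (illustrative, not from the original code): for
    # gt_window=[26, 36], duration=150 and ctx_l=75, clip_len is 2.0s, so
    # gt_st=13 and gt_ed=17.  Two positive clips are then sampled from 13..17,
    # two negatives from the remaining clips, and score_array is 1 on 13..17
    # and 0 elsewhere.
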
    def get_saliency_labels(self, rel_clip_ids, scores, ctx_l, max_n=1, add_easy_negative=True):
        """Sum the scores from the three annotators, then take the `max_n` clips
        with the maximum summed scores as hard positives and the `max_n` clips
        with the minimum summed scores as hard negatives.
        Args:
            rel_clip_ids: list(int), list of relevant clip ids
            scores: list([anno1_score, anno2_score, anno3_score]),
            ctx_l: int
            max_n: int, #clips to use as positives and negatives, for easy and hard negatives, respectively.
            add_easy_negative: bool, if True, sample easy negatives outside the relevant_clip_ids.
        """
        scores = np.array(scores)  # (#rel_clips, 3)
        agg_scores = np.sum(scores, 1)  # (#rel_clips,)
        sort_indices = np.argsort(agg_scores)  # increasing

        hard_pos_clip_indices = [min(rel_clip_ids[idx], ctx_l - 1) for idx in sort_indices[-max_n:]]
        hard_neg_clip_indices = [min(rel_clip_ids[idx], ctx_l - 1) for idx in sort_indices[:max_n]]
        easy_pos_clip_indices = []
        easy_neg_clip_indices = []
        if add_easy_negative:
            easy_neg_pool = list(set(range(ctx_l)) - set(rel_clip_ids))
            if len(easy_neg_pool) >= max_n:
                easy_pos_clip_indices = random.sample(rel_clip_ids, k=max_n)
                easy_neg_clip_indices = random.sample(easy_neg_pool, k=max_n)
            else:  # not enough easy negatives, fall back to the hard ones
                easy_pos_clip_indices = hard_pos_clip_indices
                easy_neg_clip_indices = hard_neg_clip_indices

        pos_clip_indices = hard_pos_clip_indices + easy_pos_clip_indices
        neg_clip_indices = hard_neg_clip_indices + easy_neg_clip_indices
        return pos_clip_indices, neg_clip_indices

    def get_saliency_labels_all(self, rel_clip_ids, scores, ctx_l, max_n=1, add_easy_negative=True):
        """Same as `get_saliency_labels`, but additionally returns a dense
        per-clip saliency score array.
        Args:
            rel_clip_ids: list(int), list of relevant clip ids
            scores: list([anno1_score, anno2_score, anno3_score]),
            ctx_l: int
            max_n: int, #clips to use as positives and negatives, for easy and hard negatives, respectively.
            add_easy_negative: bool, if True, sample easy negatives outside the relevant_clip_ids.
        """
        scores = np.array(scores)  # (#rel_clips, 3)
        agg_scores = np.sum(scores, 1)  # (#rel_clips,)
        sort_indices = np.argsort(agg_scores)  # increasing

        # Dense saliency scores; annotated clip ids may exceed the (possibly
        # truncated) number of video clips, so grow the array when needed.
        score_array = np.zeros(ctx_l)
        max_len = ctx_l
        for idx in range(len(rel_clip_ids)):
            if rel_clip_ids[idx] >= ctx_l:
                max_len = max(max_len, rel_clip_ids[idx])
                score_array_new = np.zeros(max_len + 1)
                score_array_new[:len(score_array)] = score_array
                score_array = score_array_new
            score_array[rel_clip_ids[idx]] = agg_scores[idx]

        hard_pos_clip_indices = [min(rel_clip_ids[idx], ctx_l - 1) for idx in sort_indices[-max_n:]]
        hard_neg_clip_indices = [min(rel_clip_ids[idx], ctx_l - 1) for idx in sort_indices[:max_n]]
        easy_pos_clip_indices = []
        easy_neg_clip_indices = []
        if add_easy_negative:
            easy_neg_pool = list(set(range(ctx_l)) - set(rel_clip_ids))
            if len(easy_neg_pool) >= max_n:
                easy_pos_clip_indices = random.sample(rel_clip_ids, k=max_n)
                easy_neg_clip_indices = random.sample(easy_neg_pool, k=max_n)
            else:  # not enough easy negatives, fall back to the hard ones
                easy_pos_clip_indices = hard_pos_clip_indices
                easy_neg_clip_indices = hard_neg_clip_indices

        pos_clip_indices = hard_pos_clip_indices + easy_pos_clip_indices
        neg_clip_indices = hard_neg_clip_indices + easy_neg_clip_indices
        return pos_clip_indices, neg_clip_indices, score_array

    def get_span_labels(self, windows, ctx_l):
        """
        windows: list([st, ed]) in seconds. E.g. [[26, 36]], corresponding st_ed clip indices [[13, 17]] (inclusive)
            Note that a maximum of `self.max_windows` windows are used.
        returns Tensor of shape (#windows, 2), each row is [center, width] normalized by video length
        """
        if len(windows) > self.max_windows:
            random.shuffle(windows)
            windows = windows[:self.max_windows]
        if self.span_loss_type == "l1":
            windows = torch.Tensor(windows) / (ctx_l * self.clip_len)
            windows = span_xx_to_cxw(windows)
        elif self.span_loss_type == "ce":
            windows = torch.Tensor([
                [int(w[0] / self.clip_len), min(int(w[1] / self.clip_len), ctx_l) - 1]
                for w in windows]).long()
        else:
            raise NotImplementedError
        return windows

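    # Worked example (illustrative, not from the original code): with
    # windows=[[26, 36]], ctx_l=75 and clip_len=2, the "l1" branch divides by
    # 75 * 2 = 150, giving [[0.1733, 0.2400]] as (start, end), which
    # span_xx_to_cxw converts to [[0.2067, 0.0667]] as (center, width).
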
    def _get_query_feat_by_qid(self, qid):
        q_feat_path = join(self.q_feat_dir, f"qid{qid}.pt")
        q_feat = torch.load(q_feat_path).numpy().astype(np.float32)
        if self.q_feat_type == "last_hidden_state":
            q_feat = q_feat[:self.max_q_l]
        if self.normalize_t:
            q_feat = l2_normalize_np_array(q_feat)
        if self.txt_drop_ratio > 0:
            q_feat = self.random_drop_rows(q_feat)
        return torch.from_numpy(q_feat)

    def random_drop_rows(self, embeddings):
        """Randomly mask `num_drop_rows` rows of the embeddings with zeros.
        Args:
            embeddings: np.ndarray (L, D)
        """
        num_drop_rows = round(len(embeddings) * self.txt_drop_ratio)
        if num_drop_rows > 0:
            row_indices = np.random.choice(
                len(embeddings), size=num_drop_rows, replace=False)
            embeddings[row_indices] = 0
        return embeddings

    def _get_video_feat_by_vid(self, vid):
        v_feat_list = []
        for _feat_dir in self.v_feat_dirs:
            _feat_path = join(_feat_dir, f"{vid}.pt")
            _feat = torch.load(_feat_path)
            # Some feature files store a dict with a "features" entry, others a raw tensor.
            if isinstance(_feat, dict):
                _feat = _feat["features"]
            _feat = _feat[:self.max_v_l].numpy().astype(np.float32)
            if self.normalize_v:
                _feat = l2_normalize_np_array(_feat)
            v_feat_list.append(_feat)

        # Different feature types may differ slightly in length; truncate to the
        # shortest before concatenating along the feature dimension.
        min_len = min([len(e) for e in v_feat_list])
        v_feat_list = [e[:min_len] for e in v_feat_list]
        v_feat = np.concatenate(v_feat_list, axis=1)
        return torch.from_numpy(v_feat)


def start_end_collate(batch):
    batch_meta = [e["meta"] for e in batch]

    model_inputs_keys = batch[0]["model_inputs"].keys()
    batched_data = dict()
    for k in model_inputs_keys:
        if k == "span_labels":
            batched_data[k] = [dict(spans=e["model_inputs"]["span_labels"]) for e in batch]
            continue
        if k in ["saliency_pos_labels", "saliency_neg_labels"]:
            batched_data[k] = torch.LongTensor([e["model_inputs"][k] for e in batch])
            continue
        if k == "saliency_all_labels":
            pad_data, _ = pad_sequences_1d([e["model_inputs"][k] for e in batch], dtype=np.float32, fixed_length=None)
            batched_data[k] = torch.tensor(pad_data, dtype=torch.float32)
            continue
        if k in ["qid", "vid"]:
            batched_data[k] = [e["model_inputs"][k] for e in batch]
            continue
        batched_data[k] = pad_sequences_1d(
            [e["model_inputs"][k] for e in batch], dtype=torch.float32, fixed_length=None)
    return batch_meta, batched_data


def prepare_batch_inputs(batched_model_inputs):
    model_inputs = dict(
        src_txt=batched_model_inputs["query_feat"][0],
        src_txt_mask=batched_model_inputs["query_feat"][1],
        src_vid=batched_model_inputs["video_feat"][0],
        src_vid_mask=batched_model_inputs["video_feat"][1],
        vid=batched_model_inputs["vid"],
        qid=batched_model_inputs["qid"],
    )
    targets = {}

    if "span_labels" in batched_model_inputs:
        targets["span_labels"] = [
            dict(spans=e["spans"])
            for e in batched_model_inputs["span_labels"]
        ]
    if "saliency_pos_labels" in batched_model_inputs:
        for name in ["saliency_pos_labels", "saliency_neg_labels"]:
            targets[name] = batched_model_inputs[name]

    if "saliency_all_labels" in batched_model_inputs:
        targets["saliency_all_labels"] = batched_model_inputs["saliency_all_labels"]
        targets["relevant_clips"] = batched_model_inputs["saliency_all_labels"]
    targets = None if len(targets) == 0 else targets
    return model_inputs, targets
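

# ---------------------------------------------------------------------------
# Usage sketch (illustrative only, not part of the original pipeline): shows
# how this dataset is typically combined with `start_end_collate` and
# `prepare_batch_inputs` in a PyTorch DataLoader.  The paths and hyper-
# parameters below are placeholders and must be adapted to your setup.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    from torch.utils.data import DataLoader

    dataset = StartEndDataset(
        dset_name="hl",
        data_path="data/qvhighlights/highlight_train_release.jsonl",  # placeholder
        v_feat_dirs=["features/qvhighlights/clip_features"],          # placeholder
        q_feat_dir="features/qvhighlights/clip_text_features",        # placeholder
        q_feat_type="last_hidden_state",
        ctx_mode="video_tef",
        clip_len=2,
        max_windows=5,
    )
    loader = DataLoader(dataset, batch_size=4, shuffle=True,
                        collate_fn=start_end_collate)
    for batch_meta, batched_inputs in loader:
        model_inputs, targets = prepare_batch_inputs(batched_inputs)
        print(model_inputs["src_vid"].shape, model_inputs["src_txt"].shape)
        break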