|
import pprint |
|
from tqdm import tqdm, trange |
|
import numpy as np |
|
import os |
|
from collections import OrderedDict, defaultdict |
|
from utils.basic_utils import AverageMeter |
|
|
|
import torch |
|
import torch.nn.functional as F |
|
import torch.backends.cudnn as cudnn |
|
from torch.utils.data import DataLoader |
|
|
|
from cg_detr.config import TestOptions |
|
from cg_detr.model import build_model |
|
from cg_detr.span_utils import span_cxw_to_xx |
|
from cg_detr.start_end_dataset import StartEndDataset, start_end_collate, prepare_batch_inputs |
|
from cg_detr.postprocessing_cg_detr import PostProcessorDETR |
|
from standalone_eval.eval import eval_submission |
|
from utils.basic_utils import save_jsonl, save_json |
|
from utils.temporal_nms import temporal_nms |
|
|
|
import logging |
|
|
|
logger = logging.getLogger(__name__) |
|
logging.basicConfig(format="%(asctime)s.%(msecs)03d:%(levelname)s:%(name)s - %(message)s", |
|
datefmt="%Y-%m-%d %H:%M:%S", |
|
level=logging.INFO) |
|
|
|
|
|
def post_processing_mr_nms(mr_res, nms_thd, max_before_nms, max_after_nms): |
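    """Apply temporal NMS to each query's predicted windows, keeping at most max_after_nms of them."""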
|
mr_res_after_nms = [] |
|
for e in mr_res: |
|
e["pred_relevant_windows"] = temporal_nms( |
|
e["pred_relevant_windows"][:max_before_nms], |
|
nms_thd=nms_thd, |
|
max_after_nms=max_after_nms |
|
) |
|
mr_res_after_nms.append(e) |
|
return mr_res_after_nms |
|
|
|
|
|
def eval_epoch_post_processing(submission, opt, gt_data, save_submission_filename): |
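    """Save the submission, evaluate it against ground truth when labels are available, and optionally repeat after temporal NMS."""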
|
|
|
logger.info("Saving/Evaluating before nms results") |
|
submission_path = os.path.join(opt.results_dir, save_submission_filename) |
|
save_jsonl(submission, submission_path) |
|
|
|
if opt.eval_split_name in ["val"]: |
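        # metrics can only be computed when ground-truth annotations are available (the val split here)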
|
metrics = eval_submission( |
|
submission, gt_data, |
|
verbose=opt.debug, match_number=not opt.debug |
|
) |
|
save_metrics_path = submission_path.replace(".jsonl", "_metrics.json") |
|
save_json(metrics, save_metrics_path, save_pretty=True, sort_keys=False) |
|
latest_file_paths = [submission_path, save_metrics_path] |
|
else: |
|
metrics = None |
|
latest_file_paths = [submission_path, ] |
|
|
|
if opt.nms_thd != -1: |
|
logger.info("[MR] Performing nms with nms_thd {}".format(opt.nms_thd)) |
|
submission_after_nms = post_processing_mr_nms( |
|
submission, nms_thd=opt.nms_thd, |
|
max_before_nms=opt.max_before_nms, max_after_nms=opt.max_after_nms |
|
) |
|
|
|
logger.info("Saving/Evaluating nms results") |
|
submission_nms_path = submission_path.replace(".jsonl", "_nms_thd_{}.jsonl".format(opt.nms_thd)) |
|
save_jsonl(submission_after_nms, submission_nms_path) |
|
if opt.eval_split_name == "val": |
|
metrics_nms = eval_submission( |
|
submission_after_nms, gt_data, |
|
verbose=opt.debug, match_number=not opt.debug |
|
) |
|
save_metrics_nms_path = submission_nms_path.replace(".jsonl", "_metrics.json") |
|
save_json(metrics_nms, save_metrics_nms_path, save_pretty=True, sort_keys=False) |
|
latest_file_paths += [submission_nms_path, save_metrics_nms_path] |
|
else: |
|
metrics_nms = None |
|
latest_file_paths = [submission_nms_path, ] |
|
else: |
|
metrics_nms = None |
|
return metrics, metrics_nms, latest_file_paths |
|
|
|
|
|
|
|
@torch.no_grad() |
|
def compute_hl_results(model, eval_loader, opt, epoch_i=None, criterion=None, tb_writer=None): |
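    """Compute highlight-detection mAP from the model's saliency scores (TVSum / YouTube-UNI)."""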
|
model.eval() |
|
if criterion: |
|
assert eval_loader.dataset.load_labels |
|
criterion.eval() |
|
|
|
loss_meters = defaultdict(AverageMeter) |
|
write_tb = tb_writer is not None and epoch_i is not None |
|
|
|
mr_res = [] |
|
|
|
topk = 5 |
|
|
|
video_ap_collected = [] |
|
for batch in tqdm(eval_loader, desc="compute st ed scores"): |
|
query_meta = batch[0] |
|
|
|
model_inputs, targets = prepare_batch_inputs(batch[1], opt.device, non_blocking=opt.pin_memory) |
|
|
|
        outputs = model(**model_inputs)

        preds = outputs['saliency_scores'].clone().detach()
|
|
|
        for meta, pred in zip(query_meta, preds):
            label = meta['label']
|
|
|
video_ap = [] |
|
|
|
|
|
if opt.dset_name in ["tvsum"]: |
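                # TVSum provides 20 annotation sets per video; compute a top-k AP against each one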
|
for i in range(20): |
|
pred=pred.cpu() |
|
cur_pred = pred[:len(label)] |
|
inds = torch.argsort(cur_pred, descending=True, dim=-1) |
|
|
|
|
|
cur_label = torch.Tensor(label)[:, i] |
|
cur_label = torch.where(cur_label > cur_label.median(), 1.0, .0) |
|
|
|
cur_label = cur_label[inds].tolist()[:topk] |
|
|
|
|
|
num_gt = sum(cur_label) |
|
if num_gt == 0: |
|
video_ap.append(0) |
|
continue |
|
|
|
hits = ap = rec = 0 |
|
prc = 1 |
|
|
|
for j, gt in enumerate(cur_label): |
|
hits += gt |
|
|
|
_rec = hits / num_gt |
|
_prc = hits / (j + 1) |
|
|
|
ap += (_rec - rec) * (prc + _prc) / 2 |
|
rec, prc = _rec, _prc |
|
|
|
video_ap.append(ap) |
|
|
|
elif opt.dset_name in ["youtube_uni"]: |
|
cur_pred = pred[:len(label)] |
|
|
|
cur_pred = cur_pred.cpu() |
|
inds = torch.argsort(cur_pred, descending=True, dim=-1) |
|
|
|
|
|
cur_label = torch.Tensor(label).squeeze()[inds].tolist() |
|
|
|
num_gt = sum(cur_label) |
|
if num_gt == 0: |
|
video_ap.append(0) |
|
continue |
|
|
|
hits = ap = rec = 0 |
|
prc = 1 |
|
|
|
for j, gt in enumerate(cur_label): |
|
hits += gt |
|
|
|
_rec = hits / num_gt |
|
_prc = hits / (j + 1) |
|
|
|
ap += (_rec - rec) * (prc + _prc) / 2 |
|
rec, prc = _rec, _prc |
|
|
|
video_ap.append(float(ap)) |
|
else: |
|
print("No such dataset") |
|
exit(-1) |
|
|
|
video_ap_collected.append(video_ap) |
|
|
|
mean_ap = np.mean(video_ap_collected) |
|
    submission = dict(mAP=round(mean_ap, 5))
|
|
|
|
|
|
|
if write_tb and criterion: |
|
for k, v in loss_meters.items(): |
|
tb_writer.add_scalar("Eval/{}".format(k), v.avg, epoch_i + 1) |
|
|
|
    return submission, loss_meters
|
|
|
|
|
|
|
@torch.no_grad() |
|
def compute_mr_results(model, eval_loader, opt, epoch_i=None, criterion=None, tb_writer=None): |
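    """Run the model over eval_loader and turn its outputs into ranked (start, end, score) windows per query."""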
|
model.eval() |
|
if criterion: |
|
assert eval_loader.dataset.load_labels |
|
criterion.eval() |
|
|
|
loss_meters = defaultdict(AverageMeter) |
|
write_tb = tb_writer is not None and epoch_i is not None |
|
|
|
mr_res = [] |
|
for batch in tqdm(eval_loader, desc="compute st ed scores"): |
|
query_meta = batch[0] |
|
|
|
model_inputs, targets = prepare_batch_inputs(batch[1], opt.device, non_blocking=opt.pin_memory) |
|
|
|
outputs = model(**model_inputs) |
|
prob = F.softmax(outputs["pred_logits"], -1) |
|
if opt.span_loss_type == "l1": |
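            # spans are regressed directly as normalized (center, width); prob[..., 0] serves as the confidence score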
|
scores = prob[..., 0] |
|
pred_spans = outputs["pred_spans"] |
|
_saliency_scores = outputs["saliency_scores"].half() |
|
saliency_scores = [] |
|
valid_vid_lengths = model_inputs["src_vid_mask"].sum(1).cpu().tolist() |
|
for j in range(len(valid_vid_lengths)): |
|
saliency_scores.append(_saliency_scores[j, :int(valid_vid_lengths[j])].tolist()) |
|
else: |
|
bsz, n_queries = outputs["pred_spans"].shape[:2] |
|
pred_spans_logits = outputs["pred_spans"].view(bsz, n_queries, 2, opt.max_v_l) |
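            # each predicted span is a pair of distributions over the max_v_l clip indices (start and end)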
|
pred_span_scores, pred_spans = F.softmax(pred_spans_logits, dim=-1).max(-1) |
|
scores = torch.prod(pred_span_scores, 2) |
|
            pred_spans[..., 1] += 1  # end index is inclusive; add 1 before converting clip indices to seconds
|
pred_spans *= opt.clip_length |
|
|
|
|
|
for idx, (meta, spans, score) in enumerate(zip(query_meta, pred_spans.cpu(), scores.cpu())): |
|
if opt.span_loss_type == "l1": |
|
spans = span_cxw_to_xx(spans) * meta["duration"] |
|
spans = torch.clamp(spans, 0, meta["duration"]) |
|
|
|
cur_ranked_preds = torch.cat([spans, score[:, None]], dim=1).tolist() |
|
if not opt.no_sort_results: |
|
cur_ranked_preds = sorted(cur_ranked_preds, key=lambda x: x[2], reverse=True) |
|
cur_ranked_preds = [[float(f"{e:.4f}") for e in row] for row in cur_ranked_preds] |
|
cur_query_pred = dict( |
|
qid=meta["qid"], |
|
query=meta["query"], |
|
vid=meta["vid"], |
|
pred_relevant_windows=cur_ranked_preds, |
|
pred_saliency_scores=saliency_scores[idx] |
|
) |
|
mr_res.append(cur_query_pred) |
|
|
|
if criterion: |
|
loss_dict = criterion(outputs, targets) |
|
weight_dict = criterion.weight_dict |
|
losses = sum(loss_dict[k] * weight_dict[k] for k in loss_dict.keys() if k in weight_dict) |
|
loss_dict["loss_overall"] = float(losses) |
|
for k, v in loss_dict.items(): |
|
loss_meters[k].update(float(v) * weight_dict[k] if k in weight_dict else float(v)) |
|
|
|
if opt.debug: |
|
break |
|
|
|
if write_tb and criterion: |
|
for k, v in loss_meters.items(): |
|
tb_writer.add_scalar("Eval/{}".format(k), v.avg, epoch_i + 1) |
|
|
|
if opt.dset_name in ['hl']: |
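        # QVHighlights ('hl') videos are at most 150 seconds; clamp windows to [0, 150] and round to clip boundaries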
|
post_processor = PostProcessorDETR( |
|
clip_length=opt.clip_length, min_ts_val=0, max_ts_val=150, |
|
min_w_l=2, max_w_l=150, move_window_method="left", |
|
process_func_names=("clip_ts", "round_multiple") |
|
) |
|
elif opt.dset_name in ['charadesSTA']: |
|
if opt.v_feat_dim == 4096: |
|
post_processor = PostProcessorDETR( |
|
clip_length=opt.clip_length, min_ts_val=0, max_ts_val=360, |
|
min_w_l=12, max_w_l=360, move_window_method="left", |
|
process_func_names=("clip_ts", "round_multiple") |
|
) |
|
else: |
|
post_processor = PostProcessorDETR( |
|
clip_length=opt.clip_length, min_ts_val=0, max_ts_val=150, |
|
min_w_l=2, max_w_l=60, move_window_method="left", |
|
process_func_names=("clip_ts", "round_multiple") |
|
) |
|
else: |
|
post_processor = PostProcessorDETR( |
|
clip_length=opt.clip_length, min_ts_val=0, max_ts_val=50000, |
|
min_w_l=0, max_w_l=50000, move_window_method="left", |
|
process_func_names=(["round_multiple"]) |
|
) |
|
|
|
mr_res = post_processor(mr_res) |
|
return mr_res, loss_meters |
|
|
|
|
|
def get_eval_res(model, eval_loader, opt, epoch_i, criterion, tb_writer):
    """Compute moment-retrieval predictions for the eval split and return them with the eval loss meters."""
|
eval_res, eval_loss_meters = compute_mr_results(model, eval_loader, opt, epoch_i, criterion, tb_writer) |
|
return eval_res, eval_loss_meters |
|
|
|
|
|
def eval_epoch(model, eval_dataset, opt, save_submission_filename, epoch_i=None, criterion=None, tb_writer=None): |
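    """Build the eval DataLoader, compute predictions for one evaluation pass, and save/evaluate the submission."""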
|
logger.info("Generate submissions") |
|
model.eval() |
|
if criterion is not None and eval_dataset.load_labels: |
|
criterion.eval() |
|
else: |
|
criterion = None |
|
|
|
if opt.dset_name == 'tacos': |
|
shuffle = True |
|
else: |
|
shuffle = False |
|
|
|
eval_loader = DataLoader( |
|
eval_dataset, |
|
collate_fn=start_end_collate, |
|
batch_size=opt.eval_bsz, |
|
num_workers=opt.num_workers, |
|
shuffle=shuffle, |
|
pin_memory=opt.pin_memory |
|
) |
|
|
|
|
|
|
|
if opt.dset_name in ['tvsum', 'youtube_uni']: |
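        # highlight-detection datasets: report saliency mAP only, no moment-retrieval submission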
|
metrics, eval_loss_meters = compute_hl_results(model, eval_loader, opt, epoch_i, criterion, tb_writer) |
|
|
|
|
|
submission = [ |
|
{"brief": metrics} |
|
] |
|
submission_path = os.path.join(opt.results_dir, "latest_metric.jsonl") |
|
save_jsonl(submission, submission_path) |
|
|
|
return submission[0], submission[0], eval_loss_meters, [submission_path] |
|
|
|
else: |
|
submission, eval_loss_meters = get_eval_res(model, eval_loader, opt, epoch_i, criterion, tb_writer) |
|
|
|
if opt.dset_name in ['charadesSTA', 'tacos', 'nlq']: |
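            # moment-retrieval-only datasets: strip saliency predictions from the submission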
|
new_submission = [] |
|
for s in submission: |
|
s.pop('pred_saliency_scores', None) |
|
new_submission.append(s) |
|
submission = new_submission |
|
|
|
if opt.no_sort_results: |
|
save_submission_filename = save_submission_filename.replace(".jsonl", "_unsorted.jsonl") |
|
metrics, metrics_nms, latest_file_paths = eval_epoch_post_processing( |
|
submission, opt, eval_dataset.data, save_submission_filename) |
|
return metrics, metrics_nms, eval_loss_meters, latest_file_paths |
|
|
|
|
|
def setup_model(opt): |
|
"""setup model/optimizer/scheduler and load checkpoints when needed""" |
|
logger.info("setup model/optimizer/scheduler") |
|
model, criterion = build_model(opt) |
|
if opt.device.type == "cuda": |
|
logger.info("CUDA enabled.") |
|
model.to(opt.device) |
|
criterion.to(opt.device) |
|
|
|
param_dicts = [{"params": [p for n, p in model.named_parameters() if p.requires_grad]}] |
|
optimizer = torch.optim.AdamW(param_dicts, lr=opt.lr, weight_decay=opt.wd) |
|
lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, opt.lr_drop) |
|
|
|
if opt.resume is not None: |
|
logger.info(f"Load checkpoint from {opt.resume}") |
|
checkpoint = torch.load(opt.resume, map_location="cpu") |
|
from collections import OrderedDict |
|
new_state_dict = OrderedDict() |
|
if 'pt' in opt.resume[:-4]: |
|
if 'asr' in opt.resume[:25]: |
|
model.load_state_dict(checkpoint["model"]) |
|
else: |
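                # strip the "module." prefix left by (Distributed)DataParallel checkpoints before loading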
|
for k, v in checkpoint["model"].items(): |
|
name = k[7:] |
|
new_state_dict[name] = v |
|
|
|
model.load_state_dict(new_state_dict) |
|
else: |
|
model.load_state_dict(checkpoint["model"]) |
|
if opt.resume_all: |
|
optimizer.load_state_dict(checkpoint['optimizer']) |
|
lr_scheduler.load_state_dict(checkpoint['lr_scheduler']) |
|
opt.start_epoch = checkpoint['epoch'] + 1 |
|
logger.info(f"Loaded model saved at epoch {checkpoint['epoch']} from checkpoint: {opt.resume}") |
|
else: |
|
logger.warning("If you intend to evaluate the model, please specify --resume with ckpt path") |
|
|
|
return model, criterion, optimizer, lr_scheduler |
|
|
|
|
|
def start_inference(train_opt=None, split=None, splitfile=None): |
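    """Parse test options, build the eval dataset and model, and run inference on the requested split."""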
|
if train_opt is not None: |
|
opt = TestOptions().parse(train_opt.a_feat_dir) |
|
else: |
|
opt = TestOptions().parse() |
|
if split is not None: |
|
opt.eval_split_name = split |
|
if splitfile is not None: |
|
opt.eval_path = splitfile |
|
|
|
print(opt.eval_split_name) |
|
print(opt.eval_path) |
|
logger.info("Setup config, data and model...") |
|
|
|
|
|
cudnn.benchmark = True |
|
cudnn.deterministic = False |
|
|
|
assert opt.eval_path is not None |
|
if opt.eval_split_name == 'val': |
|
loadlabel = True |
|
else: |
|
loadlabel = False |
|
|
|
eval_dataset = StartEndDataset( |
|
dset_name=opt.dset_name, |
|
data_path=opt.eval_path, |
|
v_feat_dirs=opt.v_feat_dirs, |
|
q_feat_dir=opt.t_feat_dir, |
|
q_feat_type="last_hidden_state", |
|
max_q_l=opt.max_q_l, |
|
max_v_l=opt.max_v_l, |
|
ctx_mode=opt.ctx_mode, |
|
data_ratio=opt.data_ratio, |
|
normalize_v=not opt.no_norm_vfeat, |
|
normalize_t=not opt.no_norm_tfeat, |
|
clip_len=opt.clip_length, |
|
max_windows=opt.max_windows, |
|
load_labels=loadlabel, |
|
span_loss_type=opt.span_loss_type, |
|
txt_drop_ratio=0, |
|
dset_domain=opt.dset_domain, |
|
) |
|
|
|
|
|
|
|
model, criterion, _, _ = setup_model(opt) |
|
|
|
save_submission_filename = "hl_{}_submission.jsonl".format( |
|
opt.eval_split_name) |
|
|
|
|
|
logger.info("Starting inference...") |
|
with torch.no_grad(): |
|
metrics_no_nms, metrics_nms, eval_loss_meters, latest_file_paths = \ |
|
eval_epoch(model, eval_dataset, opt, save_submission_filename, criterion=criterion) |
|
if opt.eval_split_name == 'val': |
|
logger.info("metrics_no_nms {}".format(pprint.pformat(metrics_no_nms["brief"], indent=4))) |
|
if metrics_nms is not None: |
|
logger.info("metrics_nms {}".format(pprint.pformat(metrics_nms["brief"], indent=4))) |
|
|
|
if __name__ == '__main__':
    from sys import argv
    # positional CLI arguments: split is argv[4], splitfile is argv[6]
    _, _, _, _, split, _, splitfile = argv

    start_inference(split=split, splitfile=splitfile)
|
|