import os
import time
import json
import pprint
import random
import numpy as np
from tqdm import tqdm, trange
from collections import defaultdict

import torch
import torch.nn as nn
import torch.backends.cudnn as cudnn
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter

from cg_detr.config import BaseOptions
from cg_detr.start_end_dataset import StartEndDataset, start_end_collate, prepare_batch_inputs
from cg_detr.inference import eval_epoch, start_inference, setup_model
from utils.basic_utils import AverageMeter, dict_to_markdown
from utils.model_utils import count_parameters

import logging
logger = logging.getLogger(__name__)
logging.basicConfig(format="%(asctime)s.%(msecs)03d:%(levelname)s:%(name)s - %(message)s",
                    datefmt="%Y-%m-%d %H:%M:%S",
                    level=logging.INFO)


def set_seed(seed, use_cuda=True):
    """Seed Python, NumPy, and torch RNGs (plus all CUDA devices when use_cuda) for reproducibility."""
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if use_cuda:
        torch.cuda.manual_seed_all(seed)


def train_epoch(model, criterion, train_loader, optimizer, opt, epoch_i, tb_writer):
    """Train for one epoch, accumulating weighted losses and per-phase timing statistics."""
    logger.info(f'[Epoch {epoch_i+1}]')
    model.train()
    criterion.train()

    time_meters = defaultdict(AverageMeter)
    loss_meters = defaultdict(AverageMeter)

    num_training_examples = len(train_loader)
    timer_dataloading = time.time()
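
    # Each iteration below is timed in four phases (data loading, input preparation, forward pass,
    # backward/optimizer step); per-phase max/min/avg are logged at the end of the epoch.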
    for batch_idx, batch in tqdm(enumerate(train_loader),
                                 desc="Training Iteration",
                                 total=num_training_examples):
        time_meters["dataloading_time"].update(time.time() - timer_dataloading)

        timer_start = time.time()
        model_inputs, targets = prepare_batch_inputs(batch[1], opt.device, non_blocking=opt.pin_memory)
        time_meters["prepare_inputs_time"].update(time.time() - timer_start)

        timer_start = time.time()
        outputs = model(**model_inputs, targets=targets)
        loss_dict = criterion(outputs, targets)
        weight_dict = criterion.weight_dict
        losses = sum(loss_dict[k] * weight_dict[k] for k in loss_dict.keys() if k in weight_dict)
        time_meters["model_forward_time"].update(time.time() - timer_start)

        timer_start = time.time()
        optimizer.zero_grad()
        losses.backward()
        if opt.grad_clip > 0:
            nn.utils.clip_grad_norm_(model.parameters(), opt.grad_clip)
        optimizer.step()
        time_meters["model_backward_time"].update(time.time() - timer_start)

        loss_dict["loss_overall"] = float(losses)  # for logging only
        for k, v in loss_dict.items():
            loss_meters[k].update(float(v) * weight_dict[k] if k in weight_dict else float(v))

        timer_dataloading = time.time()
        if opt.debug and batch_idx == 3:
            break
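
    # Epoch-level reporting: current learning rate and per-loss averages go to TensorBoard,
    # and a formatted line is appended to the training log file.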
    tb_writer.add_scalar("Train/lr", float(optimizer.param_groups[0]["lr"]), epoch_i+1)
    for k, v in loss_meters.items():
        tb_writer.add_scalar("Train/{}".format(k), v.avg, epoch_i+1)

    to_write = opt.train_log_txt_formatter.format(
        time_str=time.strftime("%Y_%m_%d_%H_%M_%S"),
        epoch=epoch_i+1,
        loss_str=" ".join(["{} {:.4f}".format(k, v.avg) for k, v in loss_meters.items()]))
    with open(opt.train_log_filepath, "a") as f:
        f.write(to_write)

    logger.info("Epoch time stats:")
    for name, meter in time_meters.items():
        d = {k: f"{getattr(meter, k):.4f}" for k in ["max", "min", "avg"]}
        logger.info(f"{name} ==> {d}")


def train(model, criterion, optimizer, lr_scheduler, train_dataset, val_dataset, opt):
    """Main training loop: train, periodically evaluate, checkpoint the best/latest model, and early-stop."""
    if opt.device.type == "cuda":
        logger.info("CUDA enabled.")
        model.to(opt.device)

    tb_writer = SummaryWriter(opt.tensorboard_log_dir)
    tb_writer.add_text("hyperparameters", dict_to_markdown(vars(opt), max_str_len=None))
    opt.train_log_txt_formatter = "{time_str} [Epoch] {epoch:03d} [Loss] {loss_str}\n"
    opt.eval_log_txt_formatter = "{time_str} [Epoch] {epoch:03d} [Loss] {loss_str} [Metrics] {eval_metrics_str}\n"

    train_loader = DataLoader(
        train_dataset,
        collate_fn=start_end_collate,
        batch_size=opt.bsz,
        num_workers=opt.num_workers,
        shuffle=True,
        pin_memory=opt.pin_memory
    )
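
    # Early-stopping bookkeeping: prev_best_score tracks the best validation score seen so far,
    # and es_cnt counts consecutive evaluations without improvement.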
    prev_best_score = 0.
    es_cnt = 0
    if opt.start_epoch is None:
        start_epoch = -1 if opt.eval_untrained else 0  # epoch -1 evaluates the untrained model
    else:
        start_epoch = opt.start_epoch
    save_submission_filename = "latest_{}_{}_preds.jsonl".format(opt.dset_name, opt.eval_split_name)
    for epoch_i in trange(start_epoch, opt.n_epoch, desc="Epoch"):
        if epoch_i > -1:
            train_epoch(model, criterion, train_loader, optimizer, opt, epoch_i, tb_writer)
            lr_scheduler.step()
        eval_epoch_interval = opt.eval_epoch
        if opt.eval_path is not None and (epoch_i + 1) % eval_epoch_interval == 0:
            with torch.no_grad():
                metrics_no_nms, metrics_nms, eval_loss_meters, latest_file_paths = \
                    eval_epoch(model, val_dataset, opt, save_submission_filename, epoch_i, criterion, tb_writer)
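
            # Append this evaluation's losses and brief metrics to the eval log and TensorBoard.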
            to_write = opt.eval_log_txt_formatter.format(
                time_str=time.strftime("%Y_%m_%d_%H_%M_%S"),
                epoch=epoch_i,
                loss_str=" ".join(["{} {:.4f}".format(k, v.avg) for k, v in eval_loss_meters.items()]),
                eval_metrics_str=json.dumps(metrics_no_nms))

            with open(opt.eval_log_filepath, "a") as f:
                f.write(to_write)
            logger.info("metrics_no_nms {}".format(pprint.pformat(metrics_no_nms["brief"], indent=4)))
            if metrics_nms is not None:
                logger.info("metrics_nms {}".format(pprint.pformat(metrics_nms["brief"], indent=4)))

            metrics = metrics_no_nms
            for k, v in metrics["brief"].items():
                tb_writer.add_scalar(f"Eval/{k}", float(v), epoch_i+1)
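
            # Model-selection score: overall moment-retrieval mAP for QVHighlights ('hl'),
            # otherwise the average of two retrieval metrics from the brief dict.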
            if opt.dset_name in ['hl']:
                stop_score = metrics["brief"]["MR-full-mAP"]
            else:
                stop_score = (metrics["brief"]["MR-full-R1@0.5"] + metrics["brief"]["MR-full-R1@0.7"]) / 2
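
            # New best validation score: reset the early-stop counter, save a "*_best.ckpt"
            # checkpoint, and promote the latest prediction files to their "best" counterparts.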
            if stop_score > prev_best_score:
                es_cnt = 0
                prev_best_score = stop_score

                checkpoint = {
                    "model": model.state_dict(),
                    "optimizer": optimizer.state_dict(),
                    "lr_scheduler": lr_scheduler.state_dict(),
                    "epoch": epoch_i,
                    "opt": opt
                }
                torch.save(checkpoint, opt.ckpt_filepath.replace(".ckpt", "_best.ckpt"))

                best_file_paths = [e.replace("latest", "best") for e in latest_file_paths]
                for src, tgt in zip(latest_file_paths, best_file_paths):
                    os.renames(src, tgt)
                logger.info("The checkpoint file has been updated.")
            else:
                es_cnt += 1
                if opt.max_es_cnt != -1 and es_cnt > opt.max_es_cnt:
                    with open(opt.train_log_filepath, "a") as f:
                        f.write(f"Early Stop at epoch {epoch_i}")
                    logger.info(f"\n>>>>> Early stop at epoch {epoch_i} {prev_best_score}\n")
                    break
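
        # Always keep a rolling "*_latest.ckpt" checkpoint for the most recent epoch.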
        checkpoint = {
            "model": model.state_dict(),
            "optimizer": optimizer.state_dict(),
            "lr_scheduler": lr_scheduler.state_dict(),
            "epoch": epoch_i,
            "opt": opt
        }
        torch.save(checkpoint, opt.ckpt_filepath.replace(".ckpt", "_latest.ckpt"))

        if opt.debug:
            break

    tb_writer.close()


def start_training():
    """Parse options, build the datasets and model, run training, and return the best-checkpoint path and eval settings."""
    logger.info("Setup config, data and model...")
    opt = BaseOptions().parse()
    set_seed(opt.seed)
    if opt.debug:  # keep the run deterministic when debugging
        cudnn.benchmark = False
        cudnn.deterministic = True

    dataset_config = dict(
        dset_name=opt.dset_name,
        data_path=opt.train_path,
        v_feat_dirs=opt.v_feat_dirs,
        q_feat_dir=opt.t_feat_dir,
        q_feat_type="last_hidden_state",
        max_q_l=opt.max_q_l,
        max_v_l=opt.max_v_l,
        ctx_mode=opt.ctx_mode,
        data_ratio=opt.data_ratio,
        normalize_v=not opt.no_norm_vfeat,
        normalize_t=not opt.no_norm_tfeat,
        clip_len=opt.clip_length,
        max_windows=opt.max_windows,
        span_loss_type=opt.span_loss_type,
        txt_drop_ratio=opt.txt_drop_ratio,
        dset_domain=opt.dset_domain,
    )
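
    # The training split uses the config above; the eval split (if given) overrides the data path,
    # disables text dropout, and points q_feat_dir at "text_features" instead of "sub_features".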
    dataset_config["data_path"] = opt.train_path
    train_dataset = StartEndDataset(**dataset_config)

    if opt.eval_path is not None:
        dataset_config["data_path"] = opt.eval_path
        dataset_config["txt_drop_ratio"] = 0
        dataset_config["q_feat_dir"] = opt.t_feat_dir.replace("sub_features", "text_features")
        eval_dataset = StartEndDataset(**dataset_config)
    else:
        eval_dataset = None

    model, criterion, optimizer, lr_scheduler = setup_model(opt)
    logger.info(f"Model {model}")
    count_parameters(model)
    logger.info("Start Training...")
    train(model, criterion, optimizer, lr_scheduler, train_dataset, eval_dataset, opt)
    return opt.ckpt_filepath.replace(".ckpt", "_best.ckpt"), opt.eval_split_name, opt.eval_path, opt.debug, opt
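

# When run as a script: train, then run inference with the best checkpoint on the eval split (skipped in debug mode).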
if __name__ == '__main__':
    best_ckpt_path, eval_split_name, eval_path, debug, opt = start_training()
    if not debug:
        input_args = ["--resume", best_ckpt_path,
                      "--eval_split_name", eval_split_name,
                      "--eval_path", eval_path]

        import sys
        sys.argv[1:] = input_args
        logger.info("\n\n\nFINISHED TRAINING!!!")
        logger.info("Evaluating model at {}".format(best_ckpt_path))
        logger.info("Input args {}".format(sys.argv[1:]))
        start_inference(opt)