Spaces:
Runtime error
Runtime error
First commit
Browse files- VQ-Trans/checkpoints/train_vq.py +171 -0
VQ-Trans/checkpoints/train_vq.py
ADDED
@@ -0,0 +1,171 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import json
|
3 |
+
|
4 |
+
import torch
|
5 |
+
import torch.optim as optim
|
6 |
+
from torch.utils.tensorboard import SummaryWriter
|
7 |
+
|
8 |
+
import models.vqvae as vqvae
|
9 |
+
import utils.losses as losses
|
10 |
+
import options.option_vq as option_vq
|
11 |
+
import utils.utils_model as utils_model
|
12 |
+
from dataset import dataset_VQ, dataset_TM_eval
|
13 |
+
import utils.eval_trans as eval_trans
|
14 |
+
from options.get_eval_option import get_opt
|
15 |
+
from models.evaluator_wrapper import EvaluatorModelWrapper
|
16 |
+
import warnings
|
17 |
+
warnings.filterwarnings('ignore')
|
18 |
+
from utils.word_vectorizer import WordVectorizer
|
19 |
+
|
20 |
+
def update_lr_warm_up(optimizer, nb_iter, warm_up_iter, lr):
|
21 |
+
|
22 |
+
current_lr = lr * (nb_iter + 1) / (warm_up_iter + 1)
|
23 |
+
for param_group in optimizer.param_groups:
|
24 |
+
param_group["lr"] = current_lr
|
25 |
+
|
26 |
+
return optimizer, current_lr
|
27 |
+
|
28 |
+
##### ---- Exp dirs ---- #####
|
29 |
+
args = option_vq.get_args_parser()
|
30 |
+
torch.manual_seed(args.seed)
|
31 |
+
|
32 |
+
args.out_dir = os.path.join(args.out_dir, f'{args.exp_name}')
|
33 |
+
os.makedirs(args.out_dir, exist_ok = True)
|
34 |
+
|
35 |
+
##### ---- Logger ---- #####
|
36 |
+
logger = utils_model.get_logger(args.out_dir)
|
37 |
+
writer = SummaryWriter(args.out_dir)
|
38 |
+
logger.info(json.dumps(vars(args), indent=4, sort_keys=True))
|
39 |
+
|
40 |
+
|
41 |
+
|
42 |
+
w_vectorizer = WordVectorizer('./glove', 'our_vab')
|
43 |
+
|
44 |
+
if args.dataname == 'kit' :
|
45 |
+
dataset_opt_path = 'checkpoints/kit/Comp_v6_KLD005/opt.txt'
|
46 |
+
args.nb_joints = 21
|
47 |
+
|
48 |
+
else :
|
49 |
+
dataset_opt_path = 'checkpoints/t2m/Comp_v6_KLD005/opt.txt'
|
50 |
+
args.nb_joints = 22
|
51 |
+
|
52 |
+
logger.info(f'Training on {args.dataname}, motions are with {args.nb_joints} joints')
|
53 |
+
|
54 |
+
wrapper_opt = get_opt(dataset_opt_path, torch.device('cuda'))
|
55 |
+
eval_wrapper = EvaluatorModelWrapper(wrapper_opt)
|
56 |
+
|
57 |
+
|
58 |
+
##### ---- Dataloader ---- #####
|
59 |
+
train_loader = dataset_VQ.DATALoader(args.dataname,
|
60 |
+
args.batch_size,
|
61 |
+
window_size=args.window_size,
|
62 |
+
unit_length=2**args.down_t)
|
63 |
+
|
64 |
+
train_loader_iter = dataset_VQ.cycle(train_loader)
|
65 |
+
|
66 |
+
val_loader = dataset_TM_eval.DATALoader(args.dataname, False,
|
67 |
+
32,
|
68 |
+
w_vectorizer,
|
69 |
+
unit_length=2**args.down_t)
|
70 |
+
|
71 |
+
##### ---- Network ---- #####
|
72 |
+
net = vqvae.HumanVQVAE(args, ## use args to define different parameters in different quantizers
|
73 |
+
args.nb_code,
|
74 |
+
args.code_dim,
|
75 |
+
args.output_emb_width,
|
76 |
+
args.down_t,
|
77 |
+
args.stride_t,
|
78 |
+
args.width,
|
79 |
+
args.depth,
|
80 |
+
args.dilation_growth_rate,
|
81 |
+
args.vq_act,
|
82 |
+
args.vq_norm)
|
83 |
+
|
84 |
+
|
85 |
+
if args.resume_pth :
|
86 |
+
logger.info('loading checkpoint from {}'.format(args.resume_pth))
|
87 |
+
ckpt = torch.load(args.resume_pth, map_location='cpu')
|
88 |
+
net.load_state_dict(ckpt['net'], strict=True)
|
89 |
+
net.train()
|
90 |
+
net.cuda()
|
91 |
+
|
92 |
+
##### ---- Optimizer & Scheduler ---- #####
|
93 |
+
optimizer = optim.AdamW(net.parameters(), lr=args.lr, betas=(0.9, 0.99), weight_decay=args.weight_decay)
|
94 |
+
scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=args.lr_scheduler, gamma=args.gamma)
|
95 |
+
|
96 |
+
|
97 |
+
Loss = losses.ReConsLoss(args.recons_loss, args.nb_joints)
|
98 |
+
|
99 |
+
##### ------ warm-up ------- #####
|
100 |
+
avg_recons, avg_perplexity, avg_commit = 0., 0., 0.
|
101 |
+
|
102 |
+
for nb_iter in range(1, args.warm_up_iter):
|
103 |
+
|
104 |
+
optimizer, current_lr = update_lr_warm_up(optimizer, nb_iter, args.warm_up_iter, args.lr)
|
105 |
+
|
106 |
+
gt_motion = next(train_loader_iter)
|
107 |
+
gt_motion = gt_motion.cuda().float() # (bs, 64, dim)
|
108 |
+
|
109 |
+
pred_motion, loss_commit, perplexity = net(gt_motion)
|
110 |
+
loss_motion = Loss(pred_motion, gt_motion)
|
111 |
+
loss_vel = Loss.forward_vel(pred_motion, gt_motion)
|
112 |
+
|
113 |
+
loss = loss_motion + args.commit * loss_commit + args.loss_vel * loss_vel
|
114 |
+
|
115 |
+
optimizer.zero_grad()
|
116 |
+
loss.backward()
|
117 |
+
optimizer.step()
|
118 |
+
|
119 |
+
avg_recons += loss_motion.item()
|
120 |
+
avg_perplexity += perplexity.item()
|
121 |
+
avg_commit += loss_commit.item()
|
122 |
+
|
123 |
+
if nb_iter % args.print_iter == 0 :
|
124 |
+
avg_recons /= args.print_iter
|
125 |
+
avg_perplexity /= args.print_iter
|
126 |
+
avg_commit /= args.print_iter
|
127 |
+
|
128 |
+
logger.info(f"Warmup. Iter {nb_iter} : lr {current_lr:.5f} \t Commit. {avg_commit:.5f} \t PPL. {avg_perplexity:.2f} \t Recons. {avg_recons:.5f}")
|
129 |
+
|
130 |
+
avg_recons, avg_perplexity, avg_commit = 0., 0., 0.
|
131 |
+
|
132 |
+
##### ---- Training ---- #####
|
133 |
+
avg_recons, avg_perplexity, avg_commit = 0., 0., 0.
|
134 |
+
best_fid, best_iter, best_div, best_top1, best_top2, best_top3, best_matching, writer, logger = eval_trans.evaluation_vqvae(args.out_dir, val_loader, net, logger, writer, 0, best_fid=1000, best_iter=0, best_div=100, best_top1=0, best_top2=0, best_top3=0, best_matching=100, eval_wrapper=eval_wrapper)
|
135 |
+
|
136 |
+
for nb_iter in range(1, args.total_iter + 1):
|
137 |
+
|
138 |
+
gt_motion = next(train_loader_iter)
|
139 |
+
gt_motion = gt_motion.cuda().float() # bs, nb_joints, joints_dim, seq_len
|
140 |
+
|
141 |
+
pred_motion, loss_commit, perplexity = net(gt_motion)
|
142 |
+
loss_motion = Loss(pred_motion, gt_motion)
|
143 |
+
loss_vel = Loss.forward_vel(pred_motion, gt_motion)
|
144 |
+
|
145 |
+
loss = loss_motion + args.commit * loss_commit + args.loss_vel * loss_vel
|
146 |
+
|
147 |
+
optimizer.zero_grad()
|
148 |
+
loss.backward()
|
149 |
+
optimizer.step()
|
150 |
+
scheduler.step()
|
151 |
+
|
152 |
+
avg_recons += loss_motion.item()
|
153 |
+
avg_perplexity += perplexity.item()
|
154 |
+
avg_commit += loss_commit.item()
|
155 |
+
|
156 |
+
if nb_iter % args.print_iter == 0 :
|
157 |
+
avg_recons /= args.print_iter
|
158 |
+
avg_perplexity /= args.print_iter
|
159 |
+
avg_commit /= args.print_iter
|
160 |
+
|
161 |
+
writer.add_scalar('./Train/L1', avg_recons, nb_iter)
|
162 |
+
writer.add_scalar('./Train/PPL', avg_perplexity, nb_iter)
|
163 |
+
writer.add_scalar('./Train/Commit', avg_commit, nb_iter)
|
164 |
+
|
165 |
+
logger.info(f"Train. Iter {nb_iter} : \t Commit. {avg_commit:.5f} \t PPL. {avg_perplexity:.2f} \t Recons. {avg_recons:.5f}")
|
166 |
+
|
167 |
+
avg_recons, avg_perplexity, avg_commit = 0., 0., 0.,
|
168 |
+
|
169 |
+
if nb_iter % args.eval_iter==0 :
|
170 |
+
best_fid, best_iter, best_div, best_top1, best_top2, best_top3, best_matching, writer, logger = eval_trans.evaluation_vqvae(args.out_dir, val_loader, net, logger, writer, nb_iter, best_fid, best_iter, best_div, best_top1, best_top2, best_top3, best_matching, eval_wrapper=eval_wrapper)
|
171 |
+
|