|
|
|
|
|
|
|
|
|
|
|
|
|
from pathlib import Path |
|
import argparse |
|
import json |
|
import math |
|
import os |
|
import random |
|
import signal |
|
import subprocess |
|
import sys |
|
import time |
|
import numpy as np |
|
import wandb |
|
|
|
from PIL import Image, ImageOps, ImageFilter |
|
from torch import nn, optim |
|
import torch |
|
import torchvision |
|
import torchvision.transforms as transforms |
|
|
|
# Command-line interface for Barlow Twins pre-training. `main()` parses these
# arguments and extends the namespace with distributed-training fields.
parser = argparse.ArgumentParser(description='Barlow Twins Training')
# main_worker reads images from `<data>/train` via ImageFolder.
parser.add_argument('data', type=Path, metavar='DIR',
                    help='path to dataset')
parser.add_argument('--workers', default=8, type=int, metavar='N',
                    help='number of data loader workers')
parser.add_argument('--epochs', default=300, type=int, metavar='N',
                    help='number of total epochs to run')
# Global batch size; it is split evenly across all ranks.
parser.add_argument('--batch-size', default=512, type=int, metavar='N',
                    help='mini-batch size')
parser.add_argument('--learning-rate-weights', default=0.2, type=float, metavar='LR',
                    help='base learning rate for weights')
parser.add_argument('--learning-rate-biases', default=0.0048, type=float, metavar='LR',
                    help='base learning rate for biases and batch norm parameters')
parser.add_argument('--weight-decay', default=1e-6, type=float, metavar='W',
                    help='weight decay')
parser.add_argument('--lambd', default=0.0051, type=float, metavar='L',
                    help='weight on off-diagonal terms')
# Dash-separated hidden/output widths of the projector MLP.
parser.add_argument('--projector', default='8192-8192-8192', type=str,
                    metavar='MLP', help='projector MLP')
parser.add_argument('--print-freq', default=1, type=int, metavar='N',
                    help='print frequency')
parser.add_argument('--checkpoint-dir', default='/mnt/store/wbandar1/projects/ssl-aug-artifacts/', type=Path,
                    metavar='DIR', help='path to checkpoint directory')
# `type=str.lower` makes the flag case-insensitive (`True`/`TRUE` accepted);
# the value is still validated against the lowercase choices.
parser.add_argument('--is_mixup', default='false', type=str.lower,
                    metavar='L', help='mixup regularization', choices=['true', 'false'])
parser.add_argument('--lambda_mixup', default=0.1, type=float, metavar='L',
                    help='Hyperparameter for the regularization loss')
|
|
|
def main():
    """Parse arguments, start a wandb run and spawn one training process per GPU.

    When ``SLURM_JOB_ID`` is set, the rendezvous host, the node's rank offset
    and the world size are derived from SLURM environment variables, and
    SIGUSR1/SIGTERM handlers are installed. Otherwise a single-node run with
    rendezvous on localhost is set up.
    """
    args = parser.parse_args()
    args.is_mixup = args.is_mixup.lower() == 'true'
    args.ngpus_per_node = torch.cuda.device_count()

    # wandb logs live under the checkpoint root. With the default
    # --checkpoint-dir this is /mnt/store/wbandar1/projects/ssl-aug-artifacts/wandb_logs.
    wandb_dir = args.checkpoint_dir / 'wandb_logs'
    wandb_dir.mkdir(parents=True, exist_ok=True)
    run = wandb.init(project="Barlow-Twins-MixUp-ImageNet", config=args,
                     dir=str(wandb_dir))
    # Each run checkpoints into its own sub-directory named after the wandb run id.
    args.checkpoint_dir = args.checkpoint_dir / run.id

    if 'SLURM_JOB_ID' in os.environ:
        # SIGUSR1 requeues the SLURM job and exits; SIGTERM is ignored.
        signal.signal(signal.SIGUSR1, handle_sigusr1)
        signal.signal(signal.SIGTERM, handle_sigterm)
        # Use the first host of the allocation as the rendezvous address.
        # Assumes scontrol lists hosts in the same order on every node.
        stdout = subprocess.check_output(
            ['scontrol', 'show', 'hostnames', os.environ['SLURM_JOB_NODELIST']])
        host_name = stdout.decode().splitlines()[0]
        # Global rank of this node's first GPU; main_worker adds the local GPU index.
        args.rank = int(os.environ['SLURM_NODEID']) * args.ngpus_per_node
        args.world_size = int(os.environ['SLURM_NNODES']) * args.ngpus_per_node
        args.dist_url = f'tcp://{host_name}:58472'
    else:
        args.rank = 0
        args.dist_url = 'tcp://localhost:58472'
        args.world_size = args.ngpus_per_node

    torch.multiprocessing.spawn(main_worker, (args, run,), args.ngpus_per_node)
    wandb.finish()
|
|
|
|
|
def main_worker(gpu, args, run):
    """Train Barlow Twins in one distributed process (one per GPU).

    Args:
        gpu: local process/GPU index passed by ``torch.multiprocessing.spawn``.
        args: namespace from ``main`` (CLI options plus ``rank``,
            ``world_size``, ``dist_url`` and ``ngpus_per_node``).
        run: wandb run created in ``main``; only rank 0 logs to it.

    Raises:
        ValueError: if the global batch size is not divisible by the world size.
    """
    args.rank += gpu
    torch.distributed.init_process_group(
        backend='nccl', init_method=args.dist_url,
        world_size=args.world_size, rank=args.rank)

    # Only rank 0 writes the line-buffered stats log.
    stats_file = None
    if args.rank == 0:
        args.checkpoint_dir.mkdir(parents=True, exist_ok=True)
        stats_file = open(args.checkpoint_dir / 'stats.txt', 'a', buffering=1)
        print(' '.join(sys.argv))
        print(' '.join(sys.argv), file=stats_file)

    try:
        torch.cuda.set_device(gpu)
        torch.backends.cudnn.benchmark = True

        model = BarlowTwins(args).cuda(gpu)
        model = nn.SyncBatchNorm.convert_sync_batchnorm(model)
        # 1-D parameters (biases, norm affine params) form the second group so
        # adjust_learning_rate can give them their own learning rate.
        param_weights = [p for p in model.parameters() if p.ndim != 1]
        param_biases = [p for p in model.parameters() if p.ndim == 1]
        parameters = [{'params': param_weights}, {'params': param_biases}]
        model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[gpu])
        # lr=0 is a placeholder; adjust_learning_rate sets it every step.
        optimizer = LARS(parameters, lr=0, weight_decay=args.weight_decay,
                         weight_decay_filter=True,
                         lars_adaptation_filter=True)
        scaler = torch.cuda.amp.GradScaler(growth_interval=100, enabled=True)

        # Resume automatically if a checkpoint exists in this run's directory.
        ckpt_path = args.checkpoint_dir / 'checkpoint.pth'
        if ckpt_path.is_file():
            ckpt = torch.load(ckpt_path, map_location='cpu')
            start_epoch = ckpt['epoch']
            model.load_state_dict(ckpt['model'])
            optimizer.load_state_dict(ckpt['optimizer'])
            # Restore the AMP loss scale; checkpoints without it keep a fresh scaler.
            if 'scaler' in ckpt:
                scaler.load_state_dict(ckpt['scaler'])
        else:
            start_epoch = 0

        dataset = torchvision.datasets.ImageFolder(args.data / 'train', Transform())
        sampler = torch.utils.data.distributed.DistributedSampler(dataset)
        if args.batch_size % args.world_size != 0:
            raise ValueError(
                f'batch size {args.batch_size} is not divisible by '
                f'world size {args.world_size}')
        per_device_batch_size = args.batch_size // args.world_size
        loader = torch.utils.data.DataLoader(
            dataset, batch_size=per_device_batch_size, num_workers=args.workers,
            pin_memory=True, sampler=sampler)

        start_time = time.time()
        for epoch in range(start_epoch, args.epochs):
            sampler.set_epoch(epoch)
            # `step` is a global step counter across epochs.
            for step, ((y1, y2), _) in enumerate(loader, start=epoch * len(loader)):
                y1 = y1.cuda(gpu, non_blocking=True)
                y2 = y2.cuda(gpu, non_blocking=True)
                adjust_learning_rate(args, optimizer, loader, step)
                mixup_loss_scale = adjust_mixup_scale(loader, step, args.lambda_mixup)
                optimizer.zero_grad()
                with torch.cuda.amp.autocast(enabled=True):
                    loss_bt, loss_reg = model(y1, y2, args.is_mixup)
                    loss_regs = mixup_loss_scale * loss_reg
                    loss = loss_bt + loss_regs
                scaler.scale(loss).backward()
                scaler.step(optimizer)
                scaler.update()

                if step % args.print_freq == 0 and args.rank == 0:
                    elapsed = int(time.time() - start_time)
                    lr_weights = optimizer.param_groups[0]['lr']
                    lr_biases = optimizer.param_groups[1]['lr']
                    stats = dict(epoch=epoch, step=step,
                                 lr_weights=lr_weights,
                                 lr_biases=lr_biases,
                                 loss=loss.item(),
                                 time=elapsed)
                    print(json.dumps(stats))
                    print(json.dumps(stats), file=stats_file)
                    if args.is_mixup:
                        reg_stats = {
                            "loss_reg(unscaled)": loss_reg.item(),
                            "reg_scale": mixup_loss_scale,
                            "loss_reg(scaled)": loss_regs.item(),
                        }
                    else:
                        reg_stats = {
                            "loss_reg(unscaled)": 0.,
                            "reg_scale": 0.,
                            "loss_reg(scaled)": 0.,
                        }
                    run.log({
                        "epoch": epoch,
                        "step": step,
                        "lr_weights": lr_weights,
                        "lr_biases": lr_biases,
                        "loss": loss.item(),
                        "loss_bt": loss_bt.item(),
                        **reg_stats,
                        "time": elapsed,
                    })

            if args.rank == 0:
                # Per-epoch checkpoint; `epoch + 1` is the epoch to resume from.
                state = dict(epoch=epoch + 1, model=model.state_dict(),
                             optimizer=optimizer.state_dict(),
                             scaler=scaler.state_dict())
                torch.save(state, args.checkpoint_dir / 'checkpoint.pth')

        if args.rank == 0:
            # Only the backbone weights are exported as the final model.
            print("Saving final model ...")
            torch.save(model.module.backbone.state_dict(),
                       args.checkpoint_dir / 'resnet50.pth')
            print("Finished saving final model ...")
    finally:
        if stats_file is not None:
            stats_file.close()
|
|
|
|
|
def adjust_learning_rate(args, optimizer, loader, step):
    """Set the learning rate of both optimizer param groups for global ``step``.

    The schedule is a linear warmup over the first 10 epochs up to a peak of
    ``batch_size / 256``. After that it follows a cosine decay towards 0.1% of
    the peak at the final step. Group 0 gets the schedule times
    ``args.learning_rate_weights``; group 1 gets it times
    ``args.learning_rate_biases``.
    """
    steps_per_epoch = len(loader)
    total = args.epochs * steps_per_epoch
    warmup = 10 * steps_per_epoch
    peak = args.batch_size / 256

    if step >= warmup:
        decay_step = step - warmup
        decay_span = total - warmup
        cosine = 0.5 * (1 + math.cos(math.pi * decay_step / decay_span))
        floor = peak * 0.001
        scale = peak * cosine + floor * (1 - cosine)
    else:
        scale = peak * step / warmup

    weights_group, biases_group = optimizer.param_groups[0], optimizer.param_groups[1]
    weights_group['lr'] = scale * args.learning_rate_weights
    biases_group['lr'] = scale * args.learning_rate_biases
|
|
|
def adjust_mixup_scale(loader, step, lambda_mixup):
    """Weight of the MixUp regularizer at global ``step``.

    It ramps linearly from 0 to ``lambda_mixup`` over the first 10 epochs
    (``10 * len(loader)`` steps) and stays at ``lambda_mixup`` afterwards.
    """
    ramp_steps = 10 * len(loader)
    return lambda_mixup * step / ramp_steps if step < ramp_steps else lambda_mixup
|
|
|
def handle_sigusr1(signum, frame):
    """SIGUSR1 handler: ask SLURM to requeue this job, then exit the process.

    Args:
        signum: received signal number (unused).
        frame: interrupted stack frame (unused).
    """
    job_id = os.getenv('SLURM_JOB_ID')
    if job_id is not None:
        try:
            # argv list, no shell: the job id is passed through verbatim.
            subprocess.run(['scontrol', 'requeue', job_id], check=False)
        except OSError as err:
            # Requeueing is best effort: report the failure and still exit.
            print(f'scontrol requeue failed: {err}', file=sys.stderr)
    # sys.exit is always available; the bare `exit()` builtin is injected by
    # the `site` module and may be missing (e.g. under `python -S`).
    sys.exit()
|
|
|
|
|
def handle_sigterm(signum, frame):
    """No-op SIGTERM handler, installed by ``main`` when running under SLURM.

    Replacing the default action keeps SIGTERM from terminating the process.
    Presumably job teardown is meant to go through SIGUSR1
    (``handle_sigusr1``); confirm against the cluster's signal setup.
    """
    del signum, frame  # both unused; the handler intentionally does nothing
|
|
|
|
|
def off_diagonal(x):
    """Return the off-diagonal entries of a square 2-D tensor, flattened row-major.

    Raises AssertionError if ``x`` is not square.
    """
    rows, cols = x.shape
    assert rows == cols
    # Drop the final element and view the remaining rows*rows - 1 values as
    # (rows - 1, rows + 1). Every diagonal entry then lands in column 0, so
    # slicing that column away leaves exactly the off-diagonal elements.
    shifted = x.flatten()[:-1].view(rows - 1, rows + 1)
    return shifted[:, 1:].flatten()
|
|
|
|
|
class BarlowTwins(nn.Module):
    """ResNet-50 + MLP projector trained with the Barlow Twins objective.

    It can optionally add a MixUp regularizer on the cross-correlation
    matrices. ``forward`` returns ``(loss, lossm)``: the Barlow Twins loss and
    the MixUp term. ``lossm`` is a 0-dim zero tensor on ``loss``'s device when
    mixup is disabled.
    """

    def __init__(self, args):
        super().__init__()
        self.args = args
        self.backbone = torchvision.models.resnet50(zero_init_residual=True)
        # Remove the classification head; the projector input width is 2048.
        self.backbone.fc = nn.Identity()

        # Projector: (Linear -> BatchNorm1d -> ReLU) blocks, then a final
        # bias-free Linear. Widths come from args.projector, e.g. '8192-8192-8192'.
        sizes = [2048] + list(map(int, args.projector.split('-')))
        layers = []
        for i in range(len(sizes) - 2):
            layers.append(nn.Linear(sizes[i], sizes[i + 1], bias=False))
            layers.append(nn.BatchNorm1d(sizes[i + 1]))
            layers.append(nn.ReLU(inplace=True))
        layers.append(nn.Linear(sizes[-2], sizes[-1], bias=False))
        self.projector = nn.Sequential(*layers)

    @staticmethod
    def _standardize(z):
        """Center and scale every column of ``z`` using this rank's batch (dim 0)."""
        return (z - z.mean(dim=0)) / z.std(dim=0)

    def forward(self, y1, y2, is_mixup):
        """Compute the Barlow Twins loss and, optionally, the MixUp regularizer.

        Args:
            y1, y2: two augmented views of the same per-rank batch.
            is_mixup: when truthy, also compute the MixUp regularization term.

        Returns:
            ``(loss, lossm)``. Both are reduced over all ranks via all_reduce.
        """
        batch_size = y1.shape[0]

        z1 = self._standardize(self.projector(self.backbone(y1)))
        z2 = self._standardize(self.projector(self.backbone(y2)))

        # Cross-correlation. Dividing by args.batch_size (the batch size summed
        # over all ranks, see main_worker) and then all_reduce-summing yields
        # the average over the whole distributed batch.
        c = z1.T @ z2
        c.div_(self.args.batch_size)
        torch.distributed.all_reduce(c)

        on_diag = torch.diagonal(c).add_(-1).pow_(2).sum()
        off_diag = off_diagonal(c).pow_(2).sum()
        loss = on_diag + self.args.lambd * off_diag

        if not is_mixup:
            # Zero on the same device/dtype as `loss`. A CPU tensor here cannot
            # be added to the CUDA loss in the training loop.
            return loss, loss.new_zeros(())

        # Mix view 1 with a shuffled view 2 of the batch.
        index = torch.randperm(batch_size).to(y1.device, non_blocking=True)
        alpha = np.random.beta(1.0, 1.0)
        ym = alpha * y1 + (1. - alpha) * y2[index, :]
        zm = self._standardize(self.projector(self.backbone(ym)))

        # Targets assume the embedding mixes like the input:
        # zm ~ alpha * z1 + (1 - alpha) * z2[index]. Both targets therefore use
        # the same mixture on the left-hand side.
        z2_perm = z2[index, :]

        cc_m_1 = zm.T @ z1
        cc_m_1.div_(self.args.batch_size)
        cc_m_1_gt = alpha * (z1.T @ z1) + (1. - alpha) * (z2_perm.T @ z1)
        cc_m_1_gt.div_(self.args.batch_size)

        cc_m_2 = zm.T @ z2
        cc_m_2.div_(self.args.batch_size)
        cc_m_2_gt = alpha * (z1.T @ z2) + (1. - alpha) * (z2_perm.T @ z2)
        cc_m_2_gt.div_(self.args.batch_size)

        # Same reduction order on every rank.
        for cc in (cc_m_1, cc_m_1_gt, cc_m_2, cc_m_2_gt):
            torch.distributed.all_reduce(cc)

        lossm = 0.5 * self.args.lambd * (
            (cc_m_1 - cc_m_1_gt).pow_(2).sum() + (cc_m_2 - cc_m_2_gt).pow_(2).sum())
        return loss, lossm
|
|
|
class LARS(optim.Optimizer):
    """SGD with momentum plus Layer-wise Adaptive Rate Scaling.

    For each parameter the update direction (gradient plus optional weight
    decay) is scaled by ``eta * ||p|| / ||update||``, falling back to 1 when
    either norm is zero. The result is accumulated into a momentum buffer
    ``mu``, and ``p -= lr * mu`` is applied.

    Args:
        params: iterable of parameters or param-group dicts.
        lr: learning rate (main_worker passes 0 and adjust_learning_rate
            overwrites it every step).
        weight_decay: coefficient of ``p`` added to the gradient.
        momentum: decay factor of the ``mu`` buffer.
        eta: LARS trust coefficient.
        weight_decay_filter: if True, 1-D params skip weight decay.
        lars_adaptation_filter: if True, 1-D params skip LARS scaling.
    """

    def __init__(self, params, lr, weight_decay=0, momentum=0.9, eta=0.001,
                 weight_decay_filter=False, lars_adaptation_filter=False):
        defaults = dict(lr=lr, weight_decay=weight_decay, momentum=momentum,
                        eta=eta, weight_decay_filter=weight_decay_filter,
                        lars_adaptation_filter=lars_adaptation_filter)
        super().__init__(params, defaults)

    def exclude_bias_and_norm(self, p):
        """True for 1-D parameters (e.g. biases and normalization weights)."""
        return p.ndim == 1

    @torch.no_grad()
    def step(self, closure=None):
        """Perform one optimization step.

        Args:
            closure: optional callable that re-evaluates the model and returns
                the loss, per the ``torch.optim.Optimizer.step`` contract.

        Returns:
            The closure's return value, or None when no closure is given.
        """
        loss = None
        if closure is not None:
            # The closure must be able to build a graph and backprop.
            with torch.enable_grad():
                loss = closure()

        for g in self.param_groups:
            for p in g['params']:
                dp = p.grad
                if dp is None:
                    continue

                excluded = self.exclude_bias_and_norm(p)

                if not g['weight_decay_filter'] or not excluded:
                    dp = dp.add(p, alpha=g['weight_decay'])

                if not g['lars_adaptation_filter'] or not excluded:
                    param_norm = torch.norm(p)
                    update_norm = torch.norm(dp)
                    one = torch.ones_like(param_norm)
                    # Trust ratio; 1 when either norm is zero, which avoids a
                    # division by zero.
                    q = torch.where(param_norm > 0.,
                                    torch.where(update_norm > 0,
                                                (g['eta'] * param_norm / update_norm), one), one)
                    dp = dp.mul(q)

                param_state = self.state[p]
                if 'mu' not in param_state:
                    param_state['mu'] = torch.zeros_like(p)
                mu = param_state['mu']
                mu.mul_(g['momentum']).add_(dp)

                p.add_(mu, alpha=-g['lr'])

        return loss
|
|
|
|
|
|
|
class GaussianBlur:
    """Randomly blur a PIL image.

    With probability ``p``, apply ``ImageFilter.GaussianBlur`` with a
    standard deviation of ``0.1 + 1.9 * U[0, 1)``. Otherwise return the image
    untouched.
    """

    def __init__(self, p):
        self.p = p  # probability of applying the blur

    def __call__(self, img):
        apply_blur = random.random() < self.p
        if not apply_blur:
            return img
        sigma = 0.1 + 1.9 * random.random()
        return img.filter(ImageFilter.GaussianBlur(sigma))
|
|
|
|
|
class Solarization:
    """With probability ``p``, apply PIL's ``ImageOps.solarize`` (default arguments)."""

    def __init__(self, p):
        self.p = p  # probability of solarizing

    def __call__(self, img):
        return ImageOps.solarize(img) if random.random() < self.p else img
|
|
|
|
|
class Transform:
    """Produce two augmented views ``(y1, y2)`` of one input image.

    Both views share random-resized-crop to 224 (bicubic), horizontal flip,
    color jitter, random grayscale, tensor conversion and mean/std
    normalization. They differ only in the Gaussian-blur and solarization
    probabilities.
    """

    @staticmethod
    def _pipeline(blur_p, solarize_p):
        """Build the shared augmentation chain with the given blur/solarize probabilities."""
        color_jitter = transforms.ColorJitter(brightness=0.4, contrast=0.4,
                                              saturation=0.2, hue=0.1)
        return transforms.Compose([
            transforms.RandomResizedCrop(224, interpolation=Image.BICUBIC),
            transforms.RandomHorizontalFlip(p=0.5),
            transforms.RandomApply([color_jitter], p=0.8),
            transforms.RandomGrayscale(p=0.2),
            GaussianBlur(p=blur_p),
            Solarization(p=solarize_p),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225]),
        ])

    def __init__(self):
        # First view: always blurred, never solarized.
        self.transform = self._pipeline(blur_p=1.0, solarize_p=0.0)
        # Second view: blurred 10% of the time, solarized 20% of the time.
        self.transform_prime = self._pipeline(blur_p=0.1, solarize_p=0.2)

    def __call__(self, x):
        # Tuple elements are evaluated left to right: y1 first, then y2.
        return self.transform(x), self.transform_prime(x)
|
|
|
|
|
if __name__ == '__main__':
    # Start training only when executed as a script, not on import.
    main()