darklord25's picture
Initial Commit
96ed3dd verified
# -*- coding: utf-8 -*-
import timm
import os
import sys
import argparse
import random
import numpy as np
from tqdm import tqdm
import torch
import torch.nn.functional as F
from torch import linalg as LA
from models.classification_heads import ClassificationHead
from models.R2D2_embedding import R2D2Embedding
from models.protonet_embedding import ProtoNetEmbedding
from models.ResNet12_embedding import resnet12
import torch.nn as nn
from utils import set_gpu, Timer, count_accuracy, check_dir, log
import warnings
import wandb
from itertools import combinations
from torchsummary import summary
warnings.filterwarnings("ignore")
def one_hot(indices, depth):
"""
Returns a one-hot tensor.
This is a PyTorch equivalent of Tensorflow's tf.one_hot.
Parameters:
indices: a (n_batch, m) Tensor or (m) Tensor.
depth: a scalar. Represents the depth of the one hot dimension.
Returns: a (n_batch, m, depth) Tensor or (m, depth) Tensor.
"""
encoded_indicies = torch.zeros(indices.size() + torch.Size([depth])).cuda()
index = indices.view(indices.size()+torch.Size([1]))
encoded_indicies = encoded_indicies.scatter_(1, index, 1)
return encoded_indicies
def seed_everything(seed: int):
random.seed(seed)
os.environ["PYTHONHASHSEED"] = str(seed)
np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = True
def euclidean_dist(x, y):
# x: N x D
# y: M x D
n = x.size(0)
m = y.size(0)
d = x.size(1)
assert d == y.size(1)
x = x.unsqueeze(1).expand(n, m, d)
y = y.unsqueeze(0).expand(n, m, d)
return torch.pow(x - y, 2).sum(2)
def cosine_dist(x, y):
# x: N x D
# y: M x D
n = x.size(0)
m = y.size(0)
d = x.size(1)
assert d == y.size(1)
x = x.unsqueeze(1).expand(n, m, d)
y = y.unsqueeze(0).expand(n, m, d)
cos = nn.CosineSimilarity(dim=2, eps=1e-6)
out = 1 - cos(x,y)
return out
def get_model(options):
# Choose the embedding network
if options.network == 'ProtoNet':
network = ProtoNetEmbedding().cuda()
elif options.network == 'R2D2':
network = R2D2Embedding().cuda()
elif options.network == 'ResNet':
if options.dataset == 'miniImageNet' or options.dataset == 'tieredImageNet':
network = resnet12(avg_pool=False, drop_rate=0.1,
dropblock_size=5,num_layer=options.num_layer).cuda()
network = torch.nn.DataParallel(network) # , device_ids=[1, 2])
else:
network = resnet12(avg_pool=False, drop_rate=0.1,
dropblock_size=2,num_layer=options.num_layer).cuda()
else:
print("Cannot recognize the network type")
assert(False)
# Choose the classification head
if options.head == 'Subspace':
cls_head = ClassificationHead(base_learner='Subspace').cuda()
elif options.head == 'ProtoNet':
cls_head = ClassificationHead(base_learner='ProtoNet').cuda()
elif options.head == 'Ridge':
cls_head = ClassificationHead(base_learner='Ridge').cuda()
elif options.head == 'R2D2':
cls_head = ClassificationHead(base_learner='R2D2').cuda()
elif options.head == 'SVM':
cls_head = ClassificationHead(base_learner='SVM-CS').cuda()
else:
print("Cannot recognize the dataset type")
assert(False)
return (network, cls_head)
def get_dataset(options):
# Choose the embedding network
if options.dataset == 'miniImageNet':
from dataloader.mini_imagenet import MiniImageNet, FewShotDataloader
# change it to train only, this is including the validation set
dataset_train = MiniImageNet(phase='trainval')
dataset_val = MiniImageNet(phase='test')
data_loader = FewShotDataloader
elif options.dataset == 'tieredImageNet':
from dataloader.tiered_imagenet import tieredImageNet, FewShotDataloader
dataset_train = tieredImageNet(phase='train')
dataset_val = tieredImageNet(phase='test')
data_loader = FewShotDataloader
elif options.dataset == 'CIFAR_FS':
from dataloader.CIFAR_FS import CIFAR_FS, FewShotDataloader
dataset_train = CIFAR_FS(phase='train')
dataset_val = CIFAR_FS(phase='test')
data_loader = FewShotDataloader
elif options.dataset == 'Chest':
from dataloader.chest import Chest, FewShotDataloader
dataset_train = Chest(phase='train')
dataset_val = Chest(phase='val')
data_loader = FewShotDataloader
else:
print("Cannot recognize the dataset type")
assert(False)
return (dataset_train, dataset_val, data_loader)
if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument('--num-epoch', type=int, default=80,
help='number of training epochs')
parser.add_argument('--save-epoch', type=int, default=5,
help='frequency of model saving')
parser.add_argument('--train-shot', type=int, default=5,
help='number of support examples per training class')
parser.add_argument('--val-shot', type=int, default=5,
help='number of support examples per validation class')
parser.add_argument('--train-query', type=int, default=5,
help='number of query examples per training class')
parser.add_argument('--val-episode', type=int, default=600,
help='number of episodes per validation')
parser.add_argument('--val-query', type=int, default=5,
help='number of query examples per validation class')
parser.add_argument('--train-way', type=int, default=3,
help='number of classes in one training episode')
parser.add_argument('--test-way', type=int, default=3,
help='number of classes in one test (or validation) episode')
parser.add_argument('--save-path', default='experiments')
parser.add_argument('--wandbexperiment', default="group5_subspace30",type=str)
parser.add_argument('--gpu', default='0') # using 4 gpus
parser.add_argument('--num_layer', type=int, default=30,
help='number of linear layer')
# parser.add_argument('--gpu', default='0,1,2,3') # using 4 gpus
parser.add_argument('--network', type=str, default='ResNet',
help='choose which embedding network to use. ResNet')
parser.add_argument('--head', type=str, default='Subspace',
help='choose which classification head to use. Subspace, ProtoNet, R2D2, SVM')
parser.add_argument('--dataset', type=str, default='Chest',
help='choose which classification head to use. miniImageNet, tieredImageNet, CIFAR_FS, FC100')
parser.add_argument('--episodes-per-batch', type=int, default=1,
help='number of episodes per batch')
parser.add_argument('--eps', type=float, default=0.0,
help='epsilon of label smoothing')
parser.add_argument('--wandb', action="store_true")
parser.add_argument("--wandbkey", type=str,
default='db1158429a436f94565ac9eadecc6afe9e5a0b8f',
help='Wandb project key')
# python train_my.py --gpu 2 --dataset Chest --num_layer 5
opt = parser.parse_args()
seed_everything(42)
print(opt)
opt.save_path = os.path.join(opt.save_path,opt.wandbexperiment)
if opt.wandb:
os.system('wandb login {}'.format(opt.wandbkey))
wandb.init(name=opt.wandbexperiment,
project='chest-few-shot-final')
wandb.config.update(opt)
(dataset_train, dataset_val, data_loader) = get_dataset(opt)
# Dataloader of Gidaris & Komodakis (CVPR 2018)
dloader_train = data_loader(
dataset=dataset_train,
nKnovel=opt.train_way,
nKbase=0,
nExemplars=opt.train_shot, # num training examples per novel category
# num test examples for all the novel categories
nTestNovel=opt.train_way * opt.train_query,
nTestBase=0, # num test examples for all the base categories
batch_size=opt.episodes_per_batch,
num_workers=15,
epoch_size=opt.episodes_per_batch * 1000, # num of batches per epoch
)
dloader_val = data_loader(
dataset=dataset_val,
nKnovel=opt.test_way,
nKbase=0,
nExemplars=opt.val_shot, # num training examples per novel category
# num test examples for all the novel categories
nTestNovel=opt.val_query * opt.test_way,
nTestBase=0, # num test examples for all the base categories
batch_size=1,
num_workers=15,
epoch_size=1 * opt.val_episode, # num of batches per epoch
)
set_gpu(opt.gpu)
check_dir('./experiments/')
check_dir(opt.save_path)
log_file_path = os.path.join(opt.save_path, "train_log.txt")
log(log_file_path, str(vars(opt)))
(embedding_net, cls_head) = get_model(opt)
optimizer = torch.optim.SGD(embedding_net.parameters(),lr=3e-3)
def lambda_epoch(e): return 1.0 if e < 12 else (
0.025 if e < 30 else 0.0032 if e < 45 else (0.0014 if e < 57 else (0.00052)))
## tieredimagenet###
# lambda_epoch = lambda e: 1.0 if e < 20 else (
# 0.012 if e < 45 else 0.0052 if e < 59 else (0.00054 if e < 68 else (0.00012)))
lr_scheduler = torch.optim.lr_scheduler.LambdaLR(
optimizer, lr_lambda=lambda_epoch, last_epoch=-1)
max_val_acc = 0.0
timer = Timer()
x_entropy = torch.nn.CrossEntropyLoss()
index = list(combinations([i for i in range(opt.num_layer)], 2))
for epoch in range(1, opt.num_epoch + 1):
for param_group in optimizer.param_groups:
epoch_learning_rate = param_group['lr']
log(log_file_path, 'Train Epoch: {}\tLearning Rate: {:.4f}'.format(
epoch, epoch_learning_rate))
_, _ = [x.train() for x in (embedding_net, cls_head)]
train_accuracies = []
train_losses = []
train_n_support = opt.train_way * opt.train_shot
train_n_query = opt.train_way * opt.train_query
for i, batch in enumerate(tqdm(dloader_train(epoch)), 1):
data_support, labels_support, data_query, labels_query, _, _ = [
x.cuda() for x in batch]
list_emb_query = embedding_net(data_query.view(
[-1] + list(data_query.shape[-3:]))) # [100, 2560]
list_emb_support = embedding_net(data_support.view(
[-1] + list(data_support.shape[-3:]))) # [100, 3, 32, 32] -> [100, 2560]
loss_weights = 0.
for ind in index:
loss_weights += torch.abs(F.cosine_similarity(getattr(embedding_net,f'linear{ind[0]}_1').weight.view(-1),getattr(embedding_net,f'linear{ind[1]}_1').weight.view(-1),dim=0))
log_p_y = torch.zeros(
opt.episodes_per_batch * opt.train_way * opt.train_query, opt.train_way).cuda()
for emb_support,emb_query in zip(list_emb_support, list_emb_query):
# emb_support = emb_support.view(
# opt.episodes_per_batch, train_n_support, -1) # [4, 25, 2560]
if opt.train_shot == 1:
emb_support = emb_support.view(
opt.episodes_per_batch, opt.train_way, -1) # [4,5,5,2560] --> [4, 5, 20]
else:
emb_support = emb_support.view(
opt.episodes_per_batch, opt.train_way, opt.train_shot, -1).mean(2) # [4,5,5,2560] --> [4, 5, 20]
emb_query = emb_query.view(
opt.episodes_per_batch, train_n_query, -1) # [4, 25, 2560]
dists = torch.stack(
[euclidean_dist(emb_query[i], emb_support[i]) for i in range(opt.episodes_per_batch)]) # [4,25,5]
log_p_y += F.softmax(-dists,
dim=2).view(opt.episodes_per_batch* opt.train_way* opt.train_query, -1) # [100,5]
log_p_y /= opt.num_layer
smoothed_one_hot = one_hot(
labels_query.view(-1), opt.train_way) # [100,5]
loss = x_entropy(
log_p_y.view(-1, opt.train_way), labels_query.view(-1))
acc, _ = count_accuracy(
log_p_y.view(-1, opt.train_way), labels_query.view(-1))
train_accuracies.append(acc.item())
train_losses.append(loss.item())
if (i % 100 == 0):
train_acc_avg = np.mean(np.array(train_accuracies))
log(log_file_path, 'Train Epoch: {}\tBatch: [{}/{}]\tLoss: {:.4f}\tAccuracy: {:.2f} % ({:.2f} %)'.format(
epoch, i, len(dloader_train), loss.item(), train_acc_avg, acc))
if opt.wandb:
wandb.log({'Epoch': epoch,
'lr': optimizer.param_groups[0]['lr'],"Loss":loss.item(),"Avg Accuracy":train_acc_avg,'Accuracy':acc,
'cosine loss':loss_weights})
optimizer.zero_grad()
loss += loss_weights
loss.backward()
optimizer.step()
# Evaluate on the validation split
_, _ = [x.eval() for x in (embedding_net, cls_head)]
val_accuracies = []
val_losses = []
with torch.no_grad():
for i, batch in enumerate(tqdm(dloader_val(epoch)), 1):
data_support, labels_support, data_query, labels_query, _, _ = [
x.cuda() for x in batch]
test_n_support = opt.test_way * opt.val_shot
test_n_query = opt.test_way * opt.val_query
list_emb_support = embedding_net(data_support.view(
[-1] + list(data_support.shape[-3:])))
list_emb_query = embedding_net(data_query.view(
[-1] + list(data_query.shape[-3:])))
logit_query = torch.zeros(test_n_query, opt.test_way).cuda()
for emb_support, emb_query in zip(list_emb_support, list_emb_query):
# print(emb_support.size())
emb_support = emb_support.view(1, test_n_support, -1)
# print(emb_support.size())
emb_support = emb_support.view(
1, opt.train_way, opt.train_shot, -1).mean(2) # [4, 5, 20]
emb_query = emb_query.view(1, test_n_query, -1)
# print(emb_support.size(),emb_query.size())
dists = torch.stack(
[euclidean_dist(emb_query[i], emb_support[i]) for i in range(emb_query.size(0))])
logit_query += F.softmax(-dists, dim=2).view(1 *
opt.test_way * opt.val_query, -1) # []
logit_query /= opt.num_layer
loss = x_entropy(
logit_query.view(-1, opt.test_way), labels_query.view(-1))
acc, _ = count_accuracy(
logit_query.view(-1, opt.test_way), labels_query.view(-1))
val_accuracies.append(acc.item())
val_losses.append(loss.item())
val_acc_avg = np.mean(np.array(val_accuracies))
val_acc_ci95 = 1.96 * \
np.std(np.array(val_accuracies)) / np.sqrt(opt.val_episode)
val_loss_avg = np.mean(np.array(val_losses))
if val_acc_avg > max_val_acc:
max_val_acc = val_acc_avg
torch.save({'embedding': embedding_net.state_dict(), 'head': cls_head.state_dict()},
os.path.join(opt.save_path, 'best_model.pth'))
log(log_file_path, 'Validation Epoch: {}\t\t\tLoss: {:.4f}\tAccuracy: {:.2f} ± {:.2f} % (Best)'
.format(epoch, val_loss_avg, val_acc_avg, val_acc_ci95))
else:
log(log_file_path, 'Validation Epoch: {}\t\t\tLoss: {:.4f}\tAccuracy: {:.2f} ± {:.2f} %'
.format(epoch, val_loss_avg, val_acc_avg, val_acc_ci95))
if opt.wandb:
wandb.log({"Validation Loss":val_loss_avg,"Val Avg Accuracy":val_acc_avg})
torch.save({'embedding': embedding_net.state_dict(
), 'head': cls_head.state_dict()}, os.path.join(opt.save_path, 'last_epoch.pth'))
if epoch % opt.save_epoch == 0:
torch.save({'embedding': embedding_net.state_dict(), 'head': cls_head.state_dict(
)}, os.path.join(opt.save_path, 'epoch_{}.pth'.format(epoch)))
log(log_file_path, 'Elapsed Time: {}/{}\n'.format(timer.measure(),
timer.measure(epoch / float(opt.num_epoch))))
# lr_scheduler.step()