import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision.models.resnet import resnet50, resnet18


class Model(nn.Module):
    def __init__(self, feature_dim=128, dataset='cifar10', arch='resnet50'):
        super(Model, self).__init__()

        # Build the encoder f from a torchvision ResNet backbone.
        self.f = []
        if arch == 'resnet18':
            temp_model = resnet18().named_children()
            embedding_size = 512
        elif arch == 'resnet50':
            temp_model = resnet50().named_children()
            embedding_size = 2048
        else:
            raise NotImplementedError

        for name, module in temp_model:
            if name == 'conv1':
                # Replace the 7x7 stride-2 stem with a 3x3 stride-1 conv for small images.
                module = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False)
            if dataset == 'cifar10' or dataset == 'cifar100':
                # 32x32 inputs: drop the classification head and the first max-pool.
                if not isinstance(module, nn.Linear) and not isinstance(module, nn.MaxPool2d):
                    self.f.append(module)
            elif dataset == 'tiny_imagenet' or dataset == 'stl10':
                # Larger inputs: keep the max-pool, drop only the classification head.
                if not isinstance(module, nn.Linear):
                    self.f.append(module)
        # Encoder
        self.f = nn.Sequential(*self.f)
        # Projection head g: embedding_size -> feature_dim
        self.g = nn.Sequential(nn.Linear(embedding_size, 512, bias=False), nn.BatchNorm1d(512),
                               nn.ReLU(inplace=True), nn.Linear(512, feature_dim, bias=True))

    def forward(self, x):
        x = self.f(x)
        feature = torch.flatten(x, start_dim=1)
        out = self.g(feature)
        # Return L2-normalized encoder features and projection-head outputs.
        return F.normalize(feature, dim=-1), F.normalize(out, dim=-1)
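

# Usage sketch (illustrative, not part of the original module): instantiates the
# CIFAR-10 variant of the model and checks the output shapes of a forward pass.
# The batch size and 32x32 input resolution below are assumptions for the demo.
if __name__ == '__main__':
    model = Model(feature_dim=128, dataset='cifar10', arch='resnet50')
    images = torch.randn(4, 3, 32, 32)   # hypothetical batch of CIFAR-sized images
    feature, out = model(images)
    print(feature.shape)  # torch.Size([4, 2048]) -- normalized encoder features
    print(out.shape)      # torch.Size([4, 128])  -- normalized projection outputs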