# wgcban's picture
# Upload 98 files
# 803ef9e
import torch
import torch.nn.functional as F
from .whitening import Whitening2d
from .base import BaseMethod
from .norm_mse import norm_mse_loss
class WMSE(BaseMethod):
    """Implements the W-MSE loss.

    Embeddings of every augmented view are whitened in random sub-batches of
    ``w_size`` samples. The selected loss (``norm_mse_loss`` or ``F.mse_loss``)
    is then averaged over every pair of whitened views and over ``w_iter``
    independent random partitions.
    """

    def __init__(self, cfg):
        """Init the whitening transform and the loss settings.

        Args:
            cfg: configuration object. This constructor reads ``emb``,
                ``w_eps``, ``norm``, ``w_iter``, ``bs`` and ``w_size``
                (``BaseMethod.__init__`` may read further fields).
        """
        super().__init__(cfg)
        self.whitening = Whitening2d(cfg.emb, eps=cfg.w_eps, track_running_stats=False)
        # cfg.norm selects the normalized MSE variant over plain F.mse_loss.
        self.loss_f = norm_mse_loss if cfg.norm else F.mse_loss
        # Number of independent random whitening partitions per forward pass.
        self.w_iter = cfg.w_iter
        # Whitening sub-batch size; defaults to the full configured batch size.
        self.w_size = cfg.bs if cfg.w_size is None else cfg.w_size

    def forward(self, samples):
        """Compute the W-MSE loss for one batch.

        Args:
            samples: sequence of input batches, one per augmented view. Every
                view is assumed to hold the same number of samples; the batch
                size ``bs`` is taken from ``len(samples[0])``.

        Returns:
            The loss summed over every view pair and every whitening
            iteration, divided by ``w_iter * num_pairs``.

        Raises:
            ValueError: if the batch size is not a multiple of ``w_size``
                (the random permutation is reshaped into rows of ``w_size``).
        """
        bs = len(samples[0])
        if bs % self.w_size != 0:
            raise ValueError(
                f"batch size {bs} must be divisible by whitening size {self.w_size}"
            )
        num_views = len(samples)

        # Move inputs to wherever the encoder lives instead of hard-coding
        # `.cuda()`. Fall back to the current CUDA device (the previous
        # behaviour) if the encoder exposes no parameters to inspect.
        param = next(self.model.parameters(), None)
        device = param.device if param is not None else torch.device("cuda")
        h = [self.model(x.to(device, non_blocking=True)) for x in samples]
        # All views are projected in a single call.
        # Rows [i * bs, (i + 1) * bs) belong to view i.
        h = self.head(torch.cat(h))

        loss = 0
        for _ in range(self.w_iter):
            z = torch.empty_like(h)
            # Random partition of the sample indices into sub-batches of w_size.
            perm = torch.randperm(bs).view(-1, self.w_size)
            for idx in perm:
                # The same sample indices are used for every view, so each
                # view's sub-batch contains the same samples. Each view is
                # whitened separately.
                for i in range(num_views):
                    z[idx + i * bs] = self.whitening(h[idx + i * bs])
            views = z.split(bs)
            for i in range(num_views - 1):
                for j in range(i + 1, num_views):
                    loss += self.loss_f(views[i], views[j])
        # NOTE(review): self.num_pairs is not set in this class (presumably by
        # BaseMethod). It should equal num_views * (num_views - 1) // 2 so that
        # this is a true average over the pairs summed above. Confirm.
        loss /= self.w_iter * self.num_pairs
        return loss