Cyril666's picture
First model version
6250360
import torch
from torch import nn
from torch.nn import functional as F
def Focal_Loss(pred, gt, alpha=0.25, gamma=2):
    """Compute the mean sigmoid (binary) focal loss over all elements.

    For each element, with ``p = sigmoid(pred)``::

        loss = -alpha * (1 - p) ** gamma * gt * log(p)
               - (1 - alpha) * p ** gamma * (1 - gt) * log(1 - p)

    and the result is the mean of ``loss`` over every element.

    Args:
        pred: Raw logits; the sigmoid is applied here, so do not pass
            probabilities.
        gt: Targets, element-wise combinable with ``pred``.  Presumably
            binary (0/1) or soft labels in [0, 1] -- TODO confirm with
            callers.
        alpha: Weight of the positive term; the negative term is weighted
            by ``1 - alpha``.  Defaults to 0.25.
        gamma: Focusing exponent that down-weights well-classified
            elements.  Defaults to 2.

    Returns:
        A scalar tensor holding the mean loss.
    """
    p = torch.sigmoid(pred)
    # log(p) and log(1 - p) via logsigmoid: sigmoid saturates to exactly
    # 0/1 for moderately large |logits| (float32 sigmoid(20) == 1.0), and
    # then torch.log yields -inf.  Multiplied by a zero target weight that
    # becomes 0 * -inf = NaN.  log(1 - sigmoid(x)) == logsigmoid(-x).
    log_p = F.logsigmoid(pred)
    log_one_minus_p = F.logsigmoid(-pred)
    positive_term = alpha * (1 - p) ** gamma * gt * log_p
    negative_term = (1 - alpha) * p ** gamma * (1 - gt) * log_one_minus_p
    loss = -positive_term - negative_term
    return loss.mean()