Spaces:
Build error
Build error
import torch | |
from torch import nn | |
from torch.nn import functional as F | |
def Focal_Loss(pred, gt, alpha=0.25, gamma=2):
    """Compute the mean binary focal loss from raw logits.

    With ``p = sigmoid(pred)``, each element contributes::

        -alpha * (1 - p)**gamma * gt * log(p)
        - (1 - alpha) * p**gamma * (1 - gt) * log(1 - p)

    and the result is averaged over all elements.

    Args:
        pred: Tensor of raw (pre-sigmoid) scores.
        gt: Tensor of targets that broadcasts against ``pred``; the formula
            weights the first term by ``gt`` and the second by ``1 - gt``.
        alpha: Weight on the ``gt`` term; ``1 - alpha`` weights the other.
        gamma: Exponent applied to the ``(1 - p)`` / ``p`` modulating factors.

    Returns:
        Scalar tensor holding the mean loss.
    """
    p = torch.sigmoid(pred)
    # Take both logs in log-space: log(1 - sigmoid(x)) == log(sigmoid(-x)).
    # torch.log(p) reaches -inf once the sigmoid saturates, and the term that
    # is masked out by gt / (1 - gt) then evaluates 0 * -inf = NaN.
    log_p = F.logsigmoid(pred)
    log_not_p = F.logsigmoid(-pred)
    pos_term = alpha * (1 - p) ** gamma * gt * log_p
    neg_term = (1 - alpha) * p ** gamma * (1 - gt) * log_not_p
    return -(pos_term + neg_term).mean()
# pred =torch.sigmoid(pred)
# pos_inds = gt.eq(1).float()
# neg_inds = gt.lt(1).float()
#
# loss = 0
#
# pos_loss = torch.log(pred + 1e-10) * torch.pow(pred, 2) * pos_inds
# # neg_loss = torch.log(1 - pred) * torch.pow(pred, 2) * neg_weights * neg_inds
# neg_loss = torch.log(1 - pred) * torch.pow(1 - pred, 2) * neg_inds
#
# num_pos = pos_inds.float().sum()
# num_neg = neg_inds.float().sum()
#
# pos_loss = pos_loss.sum()
# neg_loss = neg_loss.sum()
#
# if num_pos == 0:
# loss = loss - neg_loss
# else:
# # loss = loss - (pos_loss + neg_loss) / (num_pos)
# loss = loss - (pos_loss + neg_loss )
# return loss * 5
# if weight is not None and weight.sum() > 0:
# return (losses * weight).sum() / weight.sum()
# else:
# assert losses.numel() != 0
# return losses.mean()