File size: 1,656 Bytes
a84a65c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

class WeightedCrossEntropy(nn.CrossEntropyLoss):
    """Cross-entropy loss whose per-class weighting can be toggled per call.

    The underlying ``nn.CrossEntropyLoss`` is built with ``reduction='none'``
    and the reduction is done here. Each call can therefore return either:

    * the class-weighted mean ``sum_i w[y_i] * l_i / sum_i w[y_i]``, which
      matches ``nn.CrossEntropyLoss(weight=weights)``; or
    * the plain mean over non-ignored elements, which matches
      ``nn.CrossEntropyLoss()``.

    Args:
        weights: Per-class weights indexable by class id (a tensor, or
            anything ``torch.as_tensor`` accepts, e.g. a list of floats).
            Stored as a non-persistent buffer, so it follows
            ``.to(device)`` / ``.cuda()`` without being added to
            ``state_dict``.
        **pytorch_ce_loss_args: Forwarded to ``nn.CrossEntropyLoss``
            (e.g. ``ignore_index``, ``label_smoothing``). Do not pass
            ``reduction``: it is fixed to ``'none'``, and passing it raises
            ``TypeError``. Do not pass ``weight``: the per-element losses
            would then be weighted twice.
    """

    def __init__(self, weights, **pytorch_ce_loss_args) -> None:
        super().__init__(reduction='none', **pytorch_ce_loss_args)
        # A buffer (instead of a plain attribute) moves with the module on
        # .to()/.cuda(). persistent=False keeps state_dict unchanged.
        self.register_buffer('weights', torch.as_tensor(weights), persistent=False)

    def forward(self, outputs: torch.Tensor, targets: torch.Tensor,
                to_weight: bool = True) -> torch.Tensor:
        """Return the reduced (scalar) cross-entropy loss.

        Args:
            outputs: Logits, passed unchanged to ``nn.CrossEntropyLoss``.
            targets: Class indices. Probability (floating-point) targets are
                supported only with ``to_weight=False``.
            to_weight: If True, return the class-weighted mean. Otherwise
                return the plain mean over non-ignored elements.

        Raises:
            ValueError: If ``to_weight`` is True and ``targets`` is
                floating point, since probabilities cannot index
                ``self.weights``.
        """
        loss = super().forward(outputs, targets)  # per-element, unreduced

        if targets.is_floating_point():
            # Probability targets have no ignore_index and no single class
            # per element, so only the plain mean is defined here.
            if to_weight:
                raise ValueError('to_weight=True requires class-index targets')
            return loss.mean()

        # Elements whose target equals ignore_index contribute 0 to `loss`
        # and must also be excluded from the denominator, as PyTorch's own
        # 'mean' reduction does.
        keep = targets != self.ignore_index
        if to_weight:
            # Replace ignored targets with a valid index (0) before the
            # lookup, then zero their weights via the mask.
            sample_weights = self.weights[targets.masked_fill(~keep, 0)] * keep
            return (loss * sample_weights).sum() / sample_weights.sum()
        return loss.sum() / keep.sum()


if __name__ == '__main__':
    # Sanity check: WeightedCrossEntropy must reproduce PyTorch's built-in
    # losses. For class-index targets, PyTorch's weighted 'mean' reduction
    # is sum_i w[y_i] * l_i / sum_i w[y_i] over the unreduced per-sample
    # losses l_i. The unweighted reduction is the plain mean.
    torch.manual_seed(0)  # reproducible demo output
    x = torch.randn(10, 5)
    target = torch.randint(0, 5, (10,))
    weights = torch.tensor([1., 2., 3., 4., 5.])

    pytorch_weighted = nn.CrossEntropyLoss(weight=weights)
    pytorch_unweighted = nn.CrossEntropyLoss()
    custom = WeightedCrossEntropy(weights)

    weighted = custom(x, target, to_weight=True)
    unweighted = custom(x, target, to_weight=False)
    assert torch.allclose(pytorch_weighted(x, target), weighted)
    assert torch.allclose(pytorch_unweighted(x, target), unweighted)
    print(weighted, unweighted)