import logging

import numpy as np
import scipy.stats
import torch
from sklearn.metrics import average_precision_score, roc_auc_score

logger = logging.getLogger(f'main.{__name__}')


def metrics(targets, outputs, topk=(1, 5)):
    """
    Adapted from https://github.com/hche11/VGGSound/blob/master/utils.py

    Calculate statistics including accuracy@k, mAP, mROCAUC, and d-prime.

    Args:
        targets: 1d tensor, (dataset_size, ) - class indices
        outputs: 2d tensor, (dataset_size, classes_num) - pre-softmax logits
        topk: tuple of k values for accuracy@k

    Returns:
        metrics_dict: a dict of metrics
    """
    metrics_dict = dict()
    num_cls = outputs.shape[-1]

    # accuracy@k: a prediction counts as correct if the target is among the top-k logits
    _, preds = torch.topk(outputs, k=max(topk), dim=1)
    correct_for_maxtopk = preds == targets.view(-1, 1).expand_as(preds)
    for k in topk:
        metrics_dict[f'accuracy_{k}'] = float(correct_for_maxtopk[:, :k].sum() / correct_for_maxtopk.shape[0])

    # average precision, average roc_auc, and d-prime
    targets = torch.nn.functional.one_hot(targets, num_classes=num_cls)
    # predicted class probabilities (softmax preserves the ordering of the logits)
    targets_pred = torch.softmax(outputs, dim=1)
    targets = targets.numpy()
    targets_pred = targets_pred.numpy()

    # one-vs-rest
    avg_p = [average_precision_score(targets[:, c], targets_pred[:, c], average=None) for c in range(num_cls)]
    try:
        roc_aucs = [roc_auc_score(targets[:, c], targets_pred[:, c], average=None) for c in range(num_cls)]
    except ValueError:
        # roc_auc_score raises if some class has only one label value in targets;
        # fall back to chance-level placeholders.
        logger.warning('Weird... Some classes never occurred in targets. Do not trust the metrics.')
        roc_aucs = np.array([0.5])
        avg_p = np.array([0])

    metrics_dict['mAP'] = np.mean(avg_p)
    metrics_dict['mROCAUC'] = np.mean(roc_aucs)
    # Percent point function (ppf) is the inverse of the CDF (percentiles).
    metrics_dict['dprime'] = scipy.stats.norm().ppf(metrics_dict['mROCAUC']) * np.sqrt(2)

    return metrics_dict


if __name__ == '__main__':
    targets = torch.tensor([3, 3, 1, 2, 1, 0])
    outputs = torch.tensor([
        [1.2, 1.3, 1.1, 1.5],
        [1.3, 1.4, 1.0, 1.1],
        [1.5, 1.1, 1.4, 1.3],
        [1.0, 1.2, 1.4, 1.5],
        [1.2, 1.3, 1.1, 1.1],
        [1.2, 1.1, 1.1, 1.1],
    ]).float()
    metrics_dict = metrics(targets, outputs, topk=(1, 3))
    print(metrics_dict)