How do I get the accuracy of an unbalanced dataset or segmentation task?

This post defines a Class Weighted Accuracy (CWA) function. Questions about this come up often, so I thought it would help to post a tool that torchers can use and reference here.

Why?
Take, for example, a classification dataset of kittens and puppies with a ratio of 0.2:0.8 kittens to puppies. A model trained on this dataset might show an overall accuracy of 80% by just predicting “puppies” every time.

A CWA function resolves this problem by computing the accuracy of each class separately and then averaging those per-class accuracies, so every class contributes equally to the final score no matter how many samples it has.

In the kittens and puppies example above, the always-"puppies" model scores 0% on kittens and 100% on puppies, so a CWA function would give an accuracy of 50%.
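To make the arithmetic concrete, here is a minimal sketch of the kittens and puppies case. The class indices (0 = kitten, 1 = puppy) and the sample counts of 20 and 80 are my own choices to match the 0.2:0.8 ratio, and this is plain tensor arithmetic rather than the class posted further down.

```python
import torch

# 0 = kitten, 1 = puppy; 20 kittens and 80 puppies (assumed counts for a 0.2:0.8 ratio)
targets = torch.cat([torch.zeros(20, dtype=torch.long),
                     torch.ones(80, dtype=torch.long)])

# a degenerate "model" that predicts puppy for every sample
preds = torch.ones(100, dtype=torch.long)

# plain accuracy: fraction of all samples predicted correctly (about 0.8)
overall_acc = (preds == targets).float().mean()

# per-class accuracy: fraction correct among the samples of each class
per_class_acc = torch.stack([
    (preds[targets == c] == c).float().mean() for c in (0, 1)
])

# class weighted accuracy: every class counts equally
cwa_score = per_class_acc.mean()

print(overall_acc.item())
print(per_class_acc.tolist())  # [0.0, 1.0]
print(cwa_score.item())        # 0.5
```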

Other usage examples include:

  • UNet segmentation;
  • RNN or Transformer token classification;
  • Signal classification.

The code below can handle getting accuracy for nearly every type of classification problem, including binary. It can do everything, except wash the dishes.

The model outputs should be permuted to ensure they are of size (batch_size, classes, …).

If your binary classification model outputs logits, you’ll need to run outputs = torch.sigmoid(outputs) first, so that the values are probabilities of the positive class. Multi-class classification can work with logits or softmax probabilities.
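Here is a small sketch of why the sigmoid matters in the binary case. The logits are made-up values, and the two-column reframing mirrors what the class below does internally for single-channel outputs. Without the sigmoid, the decision threshold ends up at a logit of 0.5 instead of 0.

```python
import torch

# hypothetical binary model output: one logit per sample, shape (batch_size, 1)
logits = torch.tensor([[2.0], [-1.0], [0.3], [-3.0]])

# convert logits to probabilities of the positive class
probs = torch.sigmoid(logits)

# two-column form (P(class 0), P(class 1)), then pick the larger one
pred = torch.cat([1 - probs, probs], dim=1).argmax(dim=1)

# same reframing applied to raw logits: the third sample is misclassified,
# because 1 - 0.3 > 0.3 even though sigmoid(0.3) > 0.5
pred_raw = torch.cat([1 - logits, logits], dim=1).argmax(dim=1)

print(pred.tolist())      # [1, 0, 1, 0]
print(pred_raw.tolist())  # [1, 0, 0, 0]
```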

The targets should be class indices of size (batch_size, …).

The “…” are optional dims, covering k-dimensional cases.
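As an illustration of the shapes described above, here is a sketch for token classification, where models commonly emit (batch_size, seq_len, classes). The sizes and variable names are assumptions for the example only.

```python
import torch

batch_size, seq_len, num_classes = 4, 16, 5

# hypothetical token classifier output with the class dim last
logits = torch.randn(batch_size, seq_len, num_classes)

# integer class indices, one per token
targets = torch.randint(high=num_classes, size=(batch_size, seq_len))

# move the class dim to position 1: (batch_size, classes, seq_len)
outputs = logits.permute(0, 2, 1)

print(outputs.shape)  # torch.Size([4, 5, 16])
print(targets.shape)  # torch.Size([4, 16])
```

The permute only reorders dims, so the predicted class per token is unchanged.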

Other Features
This can also be used to obtain the accuracy of each class separately by setting weighted_mean = False.

The Code

import torch
import torch.nn.functional as F


class CWA():
    def __init__(self, weighted_mean=True):
        '''
        :param weighted_mean: if True, this gives the weighted average of class accuracies;
                                False gives a tensor of individual class accuracies
        '''

        self.weighted_mean = weighted_mean

        self.reset_sampling()

    def sampling(self, outputs, targets):
        '''
        Stores samplings until self.reset_sampling() is called
        :param outputs: model outputs should be tensor of shape (batch, classes, 1-dim, 2-dim, ... k-dim)
                        where k-dims are optional
        :param targets: targets should be integer class indices of shape (batch, 1-dim, 2-dim, ... k-dim);
                        class probabilities are not supported
        :return: None
        '''

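        # detach and move to CPU so new samples match the CPU buffers created in reset_sampling()
        self.sample_outputs = torch.cat([self.sample_outputs, outputs.detach().cpu()])
        self.sample_targets = torch.cat([self.sample_targets, targets.detach().cpu()])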

    def reset_sampling(self):
        '''
        Resets sampling
        :return: None
        '''
        self.sample_outputs = torch.tensor([])
        self.sample_targets = torch.tensor([])

    def get_accuracy(self):
        '''
        Gets class weighted accuracy
        :return: If self.weighted_mean, returns a class weighted accuracy, else returns accuracy of each class as vector
        '''

        # if no sampling taken, will return None
        if self.sample_outputs.dim() == 1:
            return None

        # if binary classification, reframe to two classes
        if self.sample_outputs.size(1) == 1:
            outputs = torch.cat([(1 - self.sample_outputs), self.sample_outputs], dim=1)
            targets = self.sample_targets.unsqueeze(1)
        else:
            outputs = self.sample_outputs
            targets = self.sample_targets

        N, C = outputs.size(0), outputs.size(1)

        # reshape any k-dims to -1 dim
        outputs = outputs.view(N, C, -1)
        targets = targets.view(N, -1)

        # get predictions
        pred = torch.argmax(outputs, dim=1)

        # compare with original
        correct = (pred == targets).unsqueeze(1)

        # get correct based on classes
        target_onehot = F.one_hot(targets.to(torch.int64), num_classes=C).transpose(1, 2)
        correct = correct.expand_as(target_onehot) * target_onehot

        # get the total of each class target
        total = target_onehot.sum((0, 2))

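        # get the accuracy of each class; classes absent from the targets give 0 / 0 = nan
        correct = correct.sum((0, 2)) / total
        if self.weighted_mean:
            # average only over classes that actually occur, so nan entries do not poison the mean
            return correct[~torch.isnan(correct)].mean()
        else:
            return correct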

----- USAGE -----

# instantiate class
cwa = CWA(weighted_mean=False)

# make some dummy data
outputs = torch.rand((10, 20, 224, 224)) # model outputs size (batch_size, classes, 1-dim, ..., k-dim)
targets = torch.randint(high = 20, size = (10,224,224)) # targets size (batch_size, 1-dim, ..., k-dim)

# collect samples during training
cwa.sampling(outputs, targets)

#call get_accuracy() at any time
print(cwa.get_accuracy())

# get class weighted accuracy
cwa.weighted_mean = True
print(cwa.get_accuracy())

# reset sampling between epochs
cwa.reset_sampling()

Limitations

This only works with targets which are class indices and does not work with class probabilities. Not sure what you have? Ask yourself if your targets are integers. If they are, then this should work just fine.
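If your targets are one-hot or class probabilities, one possible workaround is to collapse them to indices with argmax before passing them in. This is a sketch with made-up values, and it is my assumption that a hard label is acceptable for your use case, since it throws away the soft information.

```python
import torch

# hypothetical soft targets of shape (batch_size, classes)
soft_targets = torch.tensor([[0.9, 0.1, 0.0],
                             [0.2, 0.3, 0.5],
                             [0.0, 1.0, 0.0]])

# take the most likely class as the hard label, giving shape (batch_size,)
index_targets = soft_targets.argmax(dim=1)

print(index_targets.tolist())  # [0, 2, 1]
print(index_targets.dtype)     # torch.int64
```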

Please let me know if you have any issues, requests or questions.

Here is a function for calculating the weight vector used by CrossEntropyLoss from a set of integer targets, passed in as one-hot.

def scaled_weights(targets):
    '''
    Gets class weights
    :param targets: pass in the matrix of size (batch, classes) containing correct targets or
        labels as one_hot
    :return: returns a vector for weight rescaling for use in cross entropy loss function
    '''

    N, C = targets.size(0), targets.size(1)

    # get vector of total for each class
    totals = targets.sum(0)

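    # avoid division by zero for classes that never occur in targets
    totals = torch.where(totals == 0, torch.full_like(totals, N), totals)

    # get the scaled vector: weight_c = N / (count_c * C)
    scaled_weight = N / totals / C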

    return scaled_weight


num_classes = 10
batch_size = 100

# create dummy outputs and integer targets
model_outputs = torch.randn((batch_size, num_classes), requires_grad =True)
targets = torch.empty(batch_size, dtype=torch.long).random_(num_classes)

# pass targets into function, after converting to one_hot
scaled_weight = scaled_weights(F.one_hot(targets, num_classes))

#define loss function
loss_funct = torch.nn.CrossEntropyLoss(weight=scaled_weight)

#calculate loss
loss = loss_funct(model_outputs, targets)
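To see what the weights look like, here is a toy sketch with the same N / totals / C formula inlined. The four-sample target vector is made up for illustration. The rarer class gets the larger weight, and count times weight comes out the same for every class.

```python
import torch
import torch.nn.functional as F

# toy targets: three samples of class 0, one sample of class 1
targets = torch.tensor([0, 0, 0, 1])
onehot = F.one_hot(targets, num_classes=2)

N, C = onehot.size(0), onehot.size(1)

# number of samples per class: [3, 1]
totals = onehot.sum(0)

# same formula as scaled_weights above: roughly [0.667, 2.0]
weights = N / totals / C

# each class carries the same total weight, N / C = 2
balanced = totals * weights

print(weights)
print(balanced)
```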