This post defines a Class Weighted Accuracy (CWA) function. Questions about this come up often, so I thought it would help to post a tool torchers can use and reference here.
Why?
Take, for example, a classification dataset of kittens and puppies with a ratio of 0.2:0.8 kittens to puppies. A model trained on this dataset might show an overall accuracy of 80% by just predicting “puppies” every time.
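A CWA function resolves this problem by computing the accuracy of each class separately and then averaging those per-class accuracies with equal weight, so a rare class counts as much as a common one. Equivalently, each sample is weighted by the inverse of its class frequency.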
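In symbols (notation introduced here, not from the code): assume all $C$ classes occur at least once in the targets, let $n_c$ be the number of samples whose target is class $c$, $m_c$ how many of those were predicted correctly, and $N = \sum_c n_c$. Then:

```latex
a_c = \frac{m_c}{n_c}, \qquad
\mathrm{CWA} = \frac{1}{C}\sum_{c=1}^{C} a_c, \qquad
\mathrm{accuracy} = \frac{\sum_{c=1}^{C} m_c}{N} = \sum_{c=1}^{C} \frac{n_c}{N}\, a_c
```

The last identity shows that plain accuracy is itself an average of class accuracies weighted by the class distribution $n_c/N$, which is why it rewards always predicting the majority class. CWA gives every class the weight $1/C$ instead.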
For the always-“puppies” model above, a CWA function gives 50%: 0% accuracy on kittens and 100% on puppies, averaged.
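To double-check the 80% vs 50% numbers, here is a minimal pure-Python sketch (no PyTorch). The helper name `plain_and_class_weighted_accuracy` is made up for illustration, and it assumes every class occurs at least once in the targets.

```python
def plain_and_class_weighted_accuracy(preds, targets, num_classes):
    """Return (plain accuracy, per-class accuracies, mean of per-class accuracies)."""
    correct = [0] * num_classes  # correct predictions per target class
    total = [0] * num_classes    # number of samples per target class
    for p, t in zip(preds, targets):
        total[t] += 1
        correct[t] += int(p == t)
    plain = sum(correct) / sum(total)
    per_class = [c / n for c, n in zip(correct, total)]  # assumes every class occurs
    return plain, per_class, sum(per_class) / num_classes


KITTEN, PUPPY = 0, 1
targets = [KITTEN] * 20 + [PUPPY] * 80  # 0.2 : 0.8 kittens to puppies
preds = [PUPPY] * len(targets)          # a model that always says "puppies"

plain, per_class, cwa = plain_and_class_weighted_accuracy(preds, targets, num_classes=2)
print(plain, per_class, cwa)  # 0.8 [0.0, 1.0] 0.5
```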
Other usage examples include:
- UNet segmentation;
- RNN or Transformer token classification;
- Signal classification.
The code below computes accuracy for nearly every type of classification problem, including binary. It can do everything except wash the dishes.
The model outputs must have shape (batch_size, classes, …); permute them first if the class dimension is somewhere else (for example, last).
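If you’re using logits for binary classification (a single output channel), you’ll need to run `outputs = torch.sigmoid(outputs)` first, because the binary branch treats the outputs as probabilities. Multi-class classification works with either logits or softmax probabilities, since softmax does not change which class has the highest score.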
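Why the sigmoid matters: the binary branch stacks `[1 - p, p]` and takes the argmax, so it picks class 1 only when `p > 0.5`. For a raw logit the correct threshold is 0, so any logit between 0 and 0.5 gets the wrong class. A small pure-Python sketch of that logic (the helper names are illustrative, not part of the CWA class):

```python
import math


def sigmoid(x):
    return 1 / (1 + math.exp(-x))


def binary_pred(p):
    """Mimic the binary branch: stack [1 - p, p] and take the argmax (ties go to class 0)."""
    scores = [1 - p, p]
    return scores.index(max(scores))


def softmax(xs):
    m = max(xs)  # subtract the max for numerical stability
    exps = [math.exp(x - m) for x in xs]
    s = sum(exps)
    return [e / s for e in exps]


def argmax(xs):
    return xs.index(max(xs))


logit = 0.3  # positive logit, so the true probability is above 0.5
print(binary_pred(logit))           # 0: wrong, the raw logit is treated as a probability
print(binary_pred(sigmoid(logit)))  # 1: correct after the sigmoid

logits = [1.5, -0.2, 2.1]
print(argmax(logits), argmax(softmax(logits)))  # 2 2: softmax keeps the argmax
```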
The targets should be class indices of shape (batch_size, …).
The “…” are optional dims, covering k-dimensional cases.
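A quick way to reason about these shape rules, sketched in pure Python on shape tuples. Both helpers are hypothetical; with a real tensor you would pass the resulting order to `outputs.permute(...)`, e.g. `outputs.permute(0, 2, 1)` for a Transformer's (batch, seq_len, classes) output.

```python
def channels_last_to_first(ndim):
    """Hypothetical helper: dim order that moves the last (class) dim to position 1."""
    if ndim < 2:
        raise ValueError("outputs need at least (batch, classes) dims")
    return (0, ndim - 1) + tuple(range(1, ndim - 1))


def shapes_match(output_shape, target_shape):
    """True if outputs are (batch, classes, *k) and targets are (batch, *k)."""
    return (
        len(output_shape) == len(target_shape) + 1
        and output_shape[0] == target_shape[0]
        and tuple(output_shape[2:]) == tuple(target_shape[1:])
    )


# Token classification: (batch, seq_len, classes) needs the class dim moved to 1
print(channels_last_to_first(3))                  # (0, 2, 1)
print(shapes_match((8, 128, 5), (8, 128)))        # False: classes still last
print(shapes_match((8, 5, 128), (8, 128)))        # True after permuting
# UNet segmentation, already channels-first: (batch, classes, H, W) vs (batch, H, W)
print(shapes_match((4, 3, 64, 64), (4, 64, 64)))  # True
# Signal classification with no extra dims: (batch, classes) vs (batch,)
print(shapes_match((16, 10), (16,)))              # True
```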
Other Features
This can also be used to obtain the accuracy of each class separately by setting `weighted_mean = False`.
The Code
```python
import torch
import torch.nn.functional as F


class CWA:
    def __init__(self, weighted_mean=True):
        '''
        :param weighted_mean: if True, get_accuracy() returns the mean of the class accuracies,
            with every class weighted equally; False returns a tensor of individual class accuracies
        '''
        self.weighted_mean = weighted_mean
        self.reset_sampling()

    def sampling(self, outputs, targets):
        '''
        Stores samples until self.reset_sampling() is called
        :param outputs: model outputs, a tensor of shape (batch, classes, 1-dim, 2-dim, ..., k-dim)
            where the k-dims are optional; single-channel binary outputs must be probabilities
        :param targets: class indices of shape (batch, 1-dim, 2-dim, ..., k-dim);
            class probabilities are not supported
        :return: None
        '''
        # keep each batch in a list and concatenate in get_accuracy(); concatenating onto an
        # empty placeholder tensor can cause device (CPU vs GPU) and dtype mismatches
        self.sample_outputs.append(outputs.detach())
        self.sample_targets.append(targets.detach())

    def reset_sampling(self):
        '''
        Resets sampling
        :return: None
        '''
        self.sample_outputs = []
        self.sample_targets = []

    def get_accuracy(self):
        '''
        Gets class weighted accuracy
        :return: None if nothing has been sampled; otherwise, if self.weighted_mean, the mean of
            the class accuracies, else the accuracy of each class as a vector
            (NaN for classes that never appear in the targets)
        '''
        # if no samples have been taken, return None
        if not self.sample_outputs:
            return None
        outputs = torch.cat(self.sample_outputs)
        targets = torch.cat(self.sample_targets).long()
        # if binary classification (a single probability channel), reframe as two classes
        if outputs.size(1) == 1:
            outputs = torch.cat([1 - outputs, outputs], dim=1)
        N, C = outputs.size(0), outputs.size(1)
        # flatten any k-dims into one dim (reshape also copes with non-contiguous tensors)
        outputs = outputs.reshape(N, C, -1)
        targets = targets.reshape(N, -1)
        # get predictions: (N, L)
        pred = torch.argmax(outputs, dim=1)
        # compare predictions with targets: (N, 1, L)
        correct = (pred == targets).unsqueeze(1)
        # one-hot targets: (N, C, L)
        target_onehot = F.one_hot(targets, num_classes=C).transpose(1, 2)
        # keep only the correct predictions, split by target class
        correct = correct.expand_as(target_onehot) * target_onehot
        # number of targets of each class
        total = target_onehot.sum((0, 2))
        # proportion correct for each class (0 / 0 gives NaN for absent classes)
        correct = correct.sum((0, 2)) / total
        if self.weighted_mean:
            # average over the classes that actually occur in the targets
            return correct[total > 0].mean()
        else:
            return correct


# ----- USAGE -----
# instantiate class
cwa = CWA(weighted_mean=False)
# make some dummy data
outputs = torch.rand((10, 20, 224, 224))  # model outputs size (batch_size, classes, 1-dim, ..., k-dim)
targets = torch.randint(high=20, size=(10, 224, 224))  # targets size (batch_size, 1-dim, ..., k-dim)
# collect samples during training
cwa.sampling(outputs, targets)
# call get_accuracy() at any time
print(cwa.get_accuracy())
# get class weighted accuracy
cwa.weighted_mean = True
print(cwa.get_accuracy())
# reset sampling between epochs
cwa.reset_sampling()
```
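For sanity-checking `get_accuracy()` on tiny inputs, a slow pure-Python reference that follows the same steps (argmax over the class dim, compare with the targets, count per class) might look like this. It is a hypothetical sketch: it takes nested lists already flattened to (N, C, L) and (N, L), skips the single-channel binary case, and leaves classes that never occur in the targets out of the mean.

```python
import math


def reference_cwa(outputs, targets, weighted_mean=True):
    """Hypothetical pure-Python reference for class weighted accuracy.

    outputs: nested lists of shape (N, C, L), a score per class and position
    targets: nested lists of shape (N, L), integer class indices
    """
    num_classes = len(outputs[0])
    correct = [0] * num_classes
    total = [0] * num_classes
    for sample_scores, sample_targets in zip(outputs, targets):
        for pos, t in enumerate(sample_targets):
            scores = [sample_scores[c][pos] for c in range(num_classes)]
            pred = scores.index(max(scores))  # argmax over the class dim
            total[t] += 1
            correct[t] += int(pred == t)
    # NaN marks classes that never appear in the targets
    per_class = [c / n if n else math.nan for c, n in zip(correct, total)]
    if not weighted_mean:
        return per_class
    present = [a for a, n in zip(per_class, total) if n]
    return sum(present) / len(present)


# one sample, 3 classes, 4 positions; predictions are [0, 1, 2, 0]
outputs = [[[0.9, 0.1, 0.2, 0.7],    # class 0 scores
            [0.05, 0.8, 0.3, 0.2],   # class 1 scores
            [0.05, 0.1, 0.5, 0.1]]]  # class 2 scores
targets = [[0, 1, 1, 0]]
print(reference_cwa(outputs, targets, weighted_mean=False))  # [1.0, 0.5, nan]
print(reference_cwa(outputs, targets))                       # 0.75
```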
Limitations
This only works with targets that are class indices; it does not work with class probabilities (soft labels). Not sure what you have? Check whether your targets are whole numbers. If they are, this should work just fine.
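A rough pure-Python version of that check, plus one possible workaround if hard labels are acceptable for evaluation: collapsing each probability row to its argmax. Both helper names are hypothetical, and the conversion discards the soft-label information. Whole-number floats such as 0.0/1.0 BCE targets pass the check, and the code above casts targets to integer indices before one-hot encoding. For tensors shaped (batch, classes, …), the equivalent conversion would be `targets.argmax(dim=1)`.

```python
def looks_like_index_targets(values):
    """True if every target is a whole number (e.g. 0, 1, 2 or 0.0, 1.0)."""
    return all(float(v).is_integer() for v in values)


def soft_to_index_targets(prob_rows):
    """Collapse per-sample class-probability rows to hard class indices via argmax.

    This throws away the soft-label information.
    """
    return [row.index(max(row)) for row in prob_rows]


print(looks_like_index_targets([0, 2, 1]))              # True
print(looks_like_index_targets([0.0, 1.0, 1.0]))        # True (whole-number floats)
print(looks_like_index_targets([0.3, 0.7]))             # False: probabilities
print(soft_to_index_targets([[0.1, 0.9], [0.6, 0.4]]))  # [1, 0]
```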
Please let me know if you have any issues, requests or questions.