How to decide predicted classes in an unbalanced multi-class classification setup?

I have a multi-class classification problem where each sample is associated with one and only one label.
Also, the training classes are unbalanced: some classes have far more samples than others.
During training, I have something like this:

    self.loss = nn.CrossEntropyLoss()
    epoch_preds = []
    epoch_labels = []
    mean_loss = 0.0
    for batch_idx, (data, target, target_weight) in enumerate(data_loader):
        logits, probs = self.classifier(data)
        self.loss.weight = target_weight
        loss = self.loss(logits, target.view(-1))
        _, preds = torch.max(probs, 1)
        epoch_labels.extend(target.view(-1).tolist())
        epoch_preds.extend(preds.tolist())
        mean_loss += loss.item() / float(len(data_loader))
    if self.epoch % save_freq == 0:
        self.save_model(os.path.join(save_dir, '%03d.ckpt' % self.epoch))
    print("Training Mean Loss : {}".format(mean_loss))
    p, r, f1, _ = sklearn.metrics.precision_recall_fscore_support(epoch_labels, epoch_preds)
    self._print_metrics(p, r, f1)

where self.classifier is an instance of a network class with the following forward():

    def forward(self, input):
        logits = self.fc(input)   # final linear layer producing class scores (layer name assumed; omitted in the original post)
        probs = self.softmax(logits)
        return logits, probs

My questions are:

  1. Since softmax() is a monotonically increasing function, it’s equivalent to use the logits to decide the predicted classes, right? i.e., I could’ve written: _, preds = torch.max(logits, 1).

  2. The training samples are unbalanced across classes. In addition to using target_weight to balance the loss computation, what other techniques could we use to enforce effective training on the minority classes? I’m still seeing some minority classes that are not properly trained, i.e., their recall on the validation set is 0.

  3. Is there anything critical I’m missing in the above snippet?


  1. Yes.

  2. Class-balancing (show all classes equally; also show the hardest samples more often).
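To see point 1 concretely, here is a minimal sketch (with made-up logit values, just for illustration) showing that taking the argmax over logits and over softmax probabilities gives identical predictions, since softmax is strictly increasing elementwise:

```python
import torch

# Toy logits for a batch of 4 samples over 3 classes (made-up numbers,
# only to illustrate the argmax equivalence).
logits = torch.tensor([[2.0, 1.0, 0.1],
                       [0.5, 2.5, 0.2],
                       [1.2, 0.3, 3.3],
                       [0.0, 0.0, 1.0]])
probs = torch.softmax(logits, dim=1)

# softmax preserves the per-row ordering, so argmax is unchanged.
_, preds_from_logits = torch.max(logits, 1)
_, preds_from_probs = torch.max(probs, 1)
assert torch.equal(preds_from_logits, preds_from_probs)
print(preds_from_logits.tolist())  # → [0, 1, 2, 2]
```

This also means you can skip the softmax entirely at prediction time if you only need the class indices.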

Thank you @smth!

By class-balancing, did you mean some non-random sampling when preparing batches? For example, instead of randomly picking n samples for a batch, we pick n/c samples per class (with c classes in total) to constitute a batch. Of course, we can adjust the strategy to pick hard cases more often.

Randomly picking n/c samples per class is a valid approach. Or you could do some form of weighted random sampling in order to get roughly n/c samples per class.
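The weighted-random-sampling variant can be done with PyTorch's built-in WeightedRandomSampler: weight each sample by the inverse frequency of its class, so every class is drawn with roughly equal probability. A minimal sketch with a made-up toy dataset (the sizes and tensors are placeholders, not from the original code):

```python
import torch
from torch.utils.data import TensorDataset, DataLoader, WeightedRandomSampler

# Toy imbalanced dataset: 90 samples of class 0, 10 of class 1
# (made-up sizes, just to illustrate the sampling idea).
labels = torch.cat([torch.zeros(90, dtype=torch.long),
                    torch.ones(10, dtype=torch.long)])
data = torch.randn(100, 5)
dataset = TensorDataset(data, labels)

# Per-sample weight = inverse frequency of that sample's class.
class_counts = torch.bincount(labels).float()
sample_weights = 1.0 / class_counts[labels]
sampler = WeightedRandomSampler(sample_weights,
                                num_samples=len(dataset),
                                replacement=True)

# Note: a DataLoader takes either `sampler` or `shuffle=True`, not both.
loader = DataLoader(dataset, batch_size=20, sampler=sampler)
for batch_data, batch_labels in loader:
    # Each batch now contains roughly 10 samples of each class on average.
    print(torch.bincount(batch_labels, minlength=2).tolist())
    break
```

With replacement=True, minority-class samples are simply repeated within an epoch, which is usually the intended behavior for this kind of rebalancing.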


Thanks a lot for commenting on this! I’ll revise how samples are drawn into batches.

Do you mind sharing some code snippets on how to implement either approach in PyTorch?
I assume it’s related to the samplers in torch.utils.data, but I’m not sure which specific sampler to use for this per-class sampling. Thanks!


@smth! Thank you for the useful link!