I have a multi-class classification problem where each sample is associated with one and only one label.
Also, the training classes are unbalanced: some classes have many more samples than others.
During training, I have something like this:
# module-level imports used by this snippet
import os
import sklearn.metrics
import torch
import torch.nn as nn

self.classifier.train()
self.loss = nn.CrossEntropyLoss()
mean_loss = 0.0
epoch_preds = []
epoch_labels = []
for batch_idx, (data, target, target_weight) in enumerate(data_loader):
    logits, probs = self.classifier(data)
    # per-batch weights applied to the loss
    self.loss.weight = target_weight
    loss = self.loss(logits, target.view(-1))
    self.optimizer.zero_grad()
    loss.backward()
    self.optimizer.step()
    # predicted class = index of the largest probability
    _, preds = torch.max(probs, 1)
    epoch_labels.extend(list(target.data.numpy().flatten()))
    epoch_preds.extend(list(preds.data.numpy().flatten()))
    mean_loss += loss.data.cpu().numpy() / float(len(data_loader))
if self.epoch % save_freq == 0:
    self.save_model(os.path.join(save_dir, '%03d.ckpt' % self.epoch))
print("Training Mean Loss : {}".format(mean_loss))
p, r, f1, _ = sklearn.metrics.precision_recall_fscore_support(epoch_labels, epoch_preds)
self._print_metrics(p, r, f1)
where self.classifier is an instance of a network class with the following forward():
def forward(self, input):
    logits = self.network(input)
    probs = self.softmax(logits)
    return logits, probs
My questions are:
- Since softmax() is a monotonically increasing function, it's equivalent to use the logit values to decide the predicted classes, right? i.e., I could have written _, preds = torch.max(logits, 1) instead (see the quick check after this list).
- The training samples are unbalanced across classes. In addition to using target_weight to balance the loss computation, what other techniques could we use to get effective training on the minority classes? I'm still seeing some minority classes that are not properly trained, i.e., their recall on the validation set is 0 (a rough resampling sketch of one idea I've had is after this list).
- Is there anything critical I'm missing in the above snippet?
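
For the first question, this is the kind of quick check I have in mind (a minimal sketch with made-up logits, not output from my actual model):

import torch
import torch.nn.functional as F

# made-up logits just to illustrate the claim
logits = torch.tensor([[2.0, -1.0, 0.5],
                       [0.1, 3.0, -2.0]])
probs = F.softmax(logits, dim=1)
# softmax is monotonic within each row, so both argmaxes pick the same class indices
_, preds_from_logits = torch.max(logits, 1)
_, preds_from_probs = torch.max(probs, 1)
print(preds_from_logits)  # tensor([0, 1])
print(preds_from_probs)   # tensor([0, 1])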
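
For the second question, one idea I have been considering (besides the target_weight loss weighting) is oversampling the minority classes with a WeightedRandomSampler. Below is a rough sketch; the toy train_labels / train_dataset are placeholders for my real data, not the code above:

import numpy as np
import torch
from torch.utils.data import DataLoader, TensorDataset, WeightedRandomSampler

# toy imbalanced dataset (placeholder for my real data): class 1 is the minority
train_labels = np.array([0] * 90 + [1] * 10)
train_data = torch.randn(len(train_labels), 8)
train_dataset = TensorDataset(train_data, torch.from_numpy(train_labels))

# draw each sample with probability inversely proportional to its class frequency
class_counts = np.bincount(train_labels)
sample_weights = 1.0 / class_counts[train_labels]
sampler = WeightedRandomSampler(weights=torch.from_numpy(sample_weights),
                                num_samples=len(sample_weights),
                                replacement=True)
data_loader = DataLoader(train_dataset, batch_size=32, sampler=sampler)

Is something like this the usual way to handle it, or should the per-batch weight I'm already passing to the loss be enough?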
Thanks!