Update gradients based on worst loss in set

Hi everyone! I’m still a bit unfamiliar with PyTorch, so my apologies if this is a dumb question.

I have some data where each datapoint is a set of 10 vectors measuring the state of the same object under different circumstances. In total, the dataset comprises about 20k such sets, each describing a similar (but distinct) object, along with a label assigning each object to one of 3 classes.

I want to train a simple model to classify these objects robustly. To do this, I want to compute the cross-entropy loss for each of the 10 vectors in a set and compute the gradients from the worst one only. I tried the following, but I suspect it isn’t computing the gradients before the max(), since the model isn’t learning at all.

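```python
import torch
import torch.nn as nn

class WorstLoss(nn.Module):
    def __init__(self):
        super().__init__()
        # Keep one loss per vector so the worst one can be selected
        self.cross_entropy = nn.CrossEntropyLoss(reduction='none')

    def criterion(self, inputs, target):
        # Called once per set by vmap: inputs holds the 10 vectors' class logits,
        # target is the set's single label, repeated for all 10 vectors
        loss = self.cross_entropy(inputs, torch.ones(10).long().cuda() * target)
        # Worst (largest) per-vector loss in the set
        return loss.max()

    def forward(self, inputs, target):
        # Apply criterion to every set in the batch
        batched_criterion = torch.vmap(self.criterion)
        # Average the per-set worst losses into a scalar so backward() works
        return batched_criterion(inputs, target).mean()
```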
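One way to test the suspicion that no gradient reaches the logits before max() is a small standalone check. The sketch below assumes 3 classes and uses random logits for a single hypothetical set of 10 vectors; it backpropagates the worst per-vector loss and lists which vectors received a non-zero gradient.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)

# Hypothetical logits for one set: 10 vectors, 3 classes
logits = torch.randn(10, 3, requires_grad=True)
# The set's label (class 2 here), repeated once per vector
target = torch.full((10,), 2, dtype=torch.long)

# Per-vector cross-entropy, then keep only the worst one
losses = F.cross_entropy(logits, target, reduction='none')
worst = losses.max()
worst.backward()

# Indices of vectors whose logits received any gradient
nonzero_rows = (logits.grad.abs().sum(dim=1) > 0).nonzero().flatten()
print(losses.argmax().item(), nonzero_rows.tolist())
```

If the only index with a non-zero gradient matches the argmax of the losses, gradients are flowing through max(), just to a single vector per set.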

Is there a better approach to this? I might be missing something obvious; my apologies if that is the case. Any help would be much appreciated. Thank you for your time!


Hi! Just to be sure that I understand everything (please correct me if I’m wrong):

  • Parameter inputs of criterion has shape (10, 3): one row of class logits per vector, each obtained independently from a different input vector.
  • Parameter target of criterion is a scalar, so torch.ones(10).long().cuda() * target has shape (10,).
  • Parameter inputs of forward has shape (batch_size, 10, 3).
  • Parameter target of forward has shape (batch_size,).
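Under the shapes listed above (assuming 3 classes, so the batched logits are (batch_size, 10, 3)), the same worst-case loss can also be computed without vmap: flatten every vector into one cross-entropy call, then take the maximum over the set dimension. This is a minimal sketch; worst_set_loss is a hypothetical helper name, and it avoids the hardcoded .cuda() by building the per-vector targets from the target tensor itself, so they stay on its device.

```python
import torch
import torch.nn.functional as F

def worst_set_loss(logits: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
    """Batch mean of the worst per-vector cross-entropy in each set (hypothetical helper).

    logits: (batch_size, set_size, num_classes) class logits
    target: (batch_size,) label shared by every vector in a set
    """
    batch_size, set_size, num_classes = logits.shape
    # Give every vector its set's label: (batch_size, set_size), same device as target
    per_vector_target = target.unsqueeze(1).expand(batch_size, set_size)
    # Flatten so cross_entropy sees (N, C) logits and (N,) targets
    losses = F.cross_entropy(
        logits.reshape(-1, num_classes),
        per_vector_target.reshape(-1),
        reduction="none",
    ).view(batch_size, set_size)
    # Worst vector in each set, then a scalar over the batch for backward()
    return losses.max(dim=1).values.mean()

# Usage with random stand-in logits (4 sets of 10 vectors, 3 classes)
torch.manual_seed(0)
logits = torch.randn(4, 10, 3, requires_grad=True)
target = torch.tensor([0, 2, 1, 2])
loss = worst_set_loss(logits, target)
loss.backward()
```

Computing all losses in one call and reducing with max(dim=1) keeps everything as plain tensor operations, which sidesteps any question of how vmap handles the loss.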

I don’t see anything obviously wrong here. As a sanity check, have you tried replacing loss.max() with loss.mean()? Maybe it’s the idea of minimizing the worst loss itself that prevents your model from training. Gradients do flow through max(), but only to the vector that attains the maximum; the other 9 vectors in the set receive zero gradient. Say that among your 10 vectors, one is always pure noise. It will likely have the worst loss among the 10, so using the worst loss prevents you entirely from training on any of the other 9 vectors, and you wouldn’t learn anything in this scenario.
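The mean-versus-max sanity check can be sketched as a single loss with a switchable reduction. The helper names set_loss and vectors_with_gradient are hypothetical, and the logits are random stand-ins for model outputs; the sketch only counts how many vectors per set receive any gradient under each reduction.

```python
import torch
import torch.nn.functional as F

def set_loss(logits, target, reduction="max"):
    """Per-vector cross-entropy reduced over each set with 'max' or 'mean' (hypothetical helper).

    logits: (batch_size, set_size, num_classes); target: (batch_size,)
    """
    batch_size, set_size, num_classes = logits.shape
    losses = F.cross_entropy(
        logits.reshape(-1, num_classes),
        target.repeat_interleave(set_size),  # one copy of each set's label per vector
        reduction="none",
    ).view(batch_size, set_size)
    per_set = losses.max(dim=1).values if reduction == "max" else losses.mean(dim=1)
    return per_set.mean()

def vectors_with_gradient(reduction):
    torch.manual_seed(0)
    logits = torch.randn(2, 10, 3, requires_grad=True)  # 2 sets, 10 vectors, 3 classes
    target = torch.tensor([0, 1])
    set_loss(logits, target, reduction).backward()
    # Count vectors in each set whose logits received any gradient
    return (logits.grad.abs().sum(dim=2) > 0).sum(dim=1).tolist()

print(vectors_with_gradient("max"))   # [1, 1]: one vector per set
print(vectors_with_gradient("mean"))  # [10, 10]: every vector in each set
```

With max, each set contributes gradient through a single vector per step, so a vector that is consistently noisy could end up driving all of the updates.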