Hi everyone! I’m still a bit unfamiliar with PyTorch, so my apologies if this is a dumb question.
I have some data where each datapoint is a set of 10 vectors measuring the state of the same object under different circumstances. In total, the dataset comprises about 20k such sets, each relating to a similar (but distinct) object, along with a label identifying each object as one of 3 classes.
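For concreteness, here is random stand-in data matching the shapes I described (the feature size of 16 is made up on my part; only the set structure matters for the question):

```python
import torch

# Random stand-in data with the shapes described above.
# The feature dimension of 16 is a placeholder, not my real dimensionality.
num_sets, vectors_per_set, feature_dim = 20_000, 10, 16
data = torch.randn(num_sets, vectors_per_set, feature_dim)  # 10 state vectors per set
labels = torch.randint(0, 3, (num_sets,))                   # one of 3 classes per set
```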
I want to train a simple model to classify these objects robustly. To do this, I want to compute the cross-entropy loss for each of the 10 vectors in a set and update my gradients based on the worst one. I tried the following, but I think the gradients aren’t flowing through the max(), as the model isn’t learning at all.
class WorstLoss(nn.Module):
    def __init__(self):
        super().__init__()
        self.cross_entropy = nn.CrossEntropyLoss(reduction='none')

    def criterion(self, inputs, target):
        # Every one of the 10 vectors in a set shares the set's label.
        loss = self.cross_entropy(inputs, torch.ones(10).long().cuda() * target)
        return loss.max()

    def forward(self, inputs, target):
        batched_criterion = torch.vmap(self.criterion)
        return batched_criterion(inputs, target)
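For reference, the quantity I’m trying to compute is equivalent to this vmap-free version, which I wrote just to pin down the intended semantics (the shapes are my assumption: logits of shape (batch, 10, num_classes) and one label per set; the function name is mine):

```python
import torch
import torch.nn.functional as F

def worst_case_loss(logits, labels):
    """Cross-entropy per vector, keeping only the worst loss in each set.

    logits: (batch, n_vectors, n_classes), labels: (batch,) with one class per set.
    """
    batch, n_vectors, n_classes = logits.shape
    # Every vector in a set shares the set's label.
    per_vector_labels = labels.unsqueeze(1).expand(batch, n_vectors)
    per_vector_loss = F.cross_entropy(
        logits.reshape(-1, n_classes),
        per_vector_labels.reshape(-1),
        reduction='none',
    ).reshape(batch, n_vectors)
    # Largest (worst) per-vector loss in each set.
    return per_vector_loss.max(dim=1).values
```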
Is there a better approach to this? I might be missing something obvious; apologies if that’s the case. Any help would be much appreciated. Thank you for your time!