Ignoring large losses when backpropagating

I’m working on an image classification problem. I’m using a collection of transforms that occasionally distort certain input images beyond recognition (e.g. randomly cropping a nondescript portion of the input), and training on these transformed inputs leads to an erroneous learning step (the parameters are essentially updated in a random direction). More generally, you can think of this as having noisy labels in a way that’s difficult to correct.

I’d like to deal with this by essentially ignoring, for the purposes of backprop, datapoints with large losses within each minibatch after a certain training step. Acceptable solutions would include ignoring losses above a certain cutoff, capping losses at that cutoff, or ignoring the largest K losses in the minibatch.

Does anyone have any suggestions about how best to implement this?

Much appreciated.

Here are my attempts at both of those options for your use case.

Hard-threshold capping:

import torch

loss = torch.rand(6, requires_grad=True)    # let's say this is the per-sample loss tensor
torch.clip(loss, max=0.5).sum().backward()  # cap each loss at 0.5 before reducing

# loss.grad contains 0 for loss values greater than 0.5
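
If it helps, here is a minimal sketch of how the capping version could sit inside a training loop and only kick in after a certain step. The model, optimizer, data, and the cap / start_step values are all made up for illustration; the key parts are reduction='none' (so you get per-sample losses) and the clamp before reducing.

import torch
import torch.nn as nn

# Toy setup so the sketch runs end to end; swap in your own model and data.
model = nn.Linear(10, 3)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
criterion = nn.CrossEntropyLoss(reduction='none')  # keep per-sample losses
cap, start_step = 0.5, 2                           # illustrative values

for step in range(5):
    inputs = torch.randn(6, 10)                    # fake minibatch
    targets = torch.randint(0, 3, (6,))
    optimizer.zero_grad()
    per_sample = criterion(model(inputs), targets)
    if step >= start_step:
        # losses above the cap contribute a constant, so those samples get zero gradient
        per_sample = torch.clamp(per_sample, max=cap)
    per_sample.mean().backward()
    optimizer.step()

Note that the mean still divides by the full batch size, so capped samples simply contribute a constant to the reported loss.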

Ignoring top-k losses:

import torch

loss = torch.rand(6, requires_grad=True)
k = 3
sorted_loss, indices = torch.sort(loss)  # ascending, so the largest k come last
sorted_loss[:-k].sum().backward()        # drop the k largest before reducing

# loss.grad contains 0 for the top-k losses
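
If you'd rather skip the full sort, torch.topk with largest=False gives an equivalent result. Same stand-in loss tensor as above:

import torch

loss = torch.rand(6, requires_grad=True)
k = 3
kept, _ = torch.topk(loss, loss.numel() - k, largest=False)  # keep the N-k smallest losses
kept.sum().backward()

# loss.grad again contains 0 for the k largest losses

In a training loop you would apply the same selection to the per-sample loss tensor from the sketch above, inside the step >= start_step branch.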