Ignoring large losses when backpropagating

Here are two approaches for your use case.

Hard-threshold capping:

import torch
loss = torch.rand(6, requires_grad=True)  # let's say this is our per-sample loss tensor
torch.clip(loss, max=0.5).sum().backward() # capping loss to 0.5

# loss.grad contains 0 for loss values greater than 0.5
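As a quick sanity check (purely illustrative), you can print the gradients after the backward call above; the gradient of the clipped-and-summed loss is 1 for uncapped entries and 0 for capped ones:

print(loss)       # the raw losses
print(loss.grad)  # 1.0 where loss <= 0.5, 0.0 where loss > 0.5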

Ignoring top-k losses:

import torch
loss = torch.rand(6, requires_grad=True)
k = 3
sorted_loss, indices = torch.sort(loss)  # ascending order
sorted_loss[:-k].sum().backward()        # drop the k largest losses, backprop through the rest

# loss.grad contains 0 for top-k losses
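Equivalently, if you'd rather select the losses you keep instead of sorting and slicing, here is a sketch using torch.topk (assuming you want to keep the n - k smallest losses); it has the same effect:

import torch

loss = torch.rand(6, requires_grad=True)
k = 3
# keep the n - k smallest losses and backprop only through them
kept, _ = torch.topk(loss, loss.numel() - k, largest=False)
kept.sum().backward()

# loss.grad is again 0 for the top-k (largest) losses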