Delete samples from loss in a differentiable way

I would like to remove “easy” samples with negligible error from the loss, so that they do not dominate the mean loss value. Focal loss in the RetinaNet paper uses weighting to address this issue; I just want to remove those samples entirely.

I suppose the following is not right, as it changes the loss value in place and disrupts the computation graph.

output = network(images)
loss = loss_function(output)   # must return per-sample losses, e.g. reduction='none'
loss = loss[loss > 1e-6]       # keep only the samples with non-negligible loss
loss = torch.mean(loss)
loss.backward()

I know there are some selection functions in PyTorch, but I'm not sure which one is suitable in this case.
What is the easiest way to do this correctly, i.e. in a differentiable way, without disrupting the computation graph?

Thanks.

Hi,

Actually, your example does not do any in-place operation! It would if you were doing loss[loss > 1e-6].zero_().
So your code should work without issue, and the entries that were not selected by your mask will simply get gradients of 0 in the rest of the network.
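
For completeness, here is a minimal, self-contained sketch of that approach; the toy linear model, random data, and cross-entropy loss are just illustrative assumptions, not taken from your code. The key point is to use reduction='none' so the loss function returns one value per sample, which can then be masked:

import torch
import torch.nn as nn

torch.manual_seed(0)
network = nn.Linear(10, 3)                         # toy stand-in for your model
images = torch.randn(8, 10)                        # dummy batch
targets = torch.randint(0, 3, (8,))

criterion = nn.CrossEntropyLoss(reduction="none")  # one loss value per sample

output = network(images)
loss = criterion(output, targets)                  # shape: (8,)
loss = loss[loss > 1e-6]                           # drop "easy" samples, still differentiable
if loss.numel() > 0:                               # guard against an empty selection
    loss.mean().backward()
print(network.weight.grad.abs().sum())             # non-zero: gradients flowed through the mask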

Thanks a lot!

So this expression does not break the computation graph, and gradients are properly back-propagated as intended (zero for the low-loss samples, normal for the remaining ones)?

I thought I needed to use one of those selection functions for it to be differentiable.
I wonder how this is achieved in such cases.

And in general, which ops/functions are differentiable and which are not?
Do you know of a doc or blog post that explains this in some detail?

Advanced indexing like this is actually implemented as a mix of those selection functions under the hood.
(Almost) all functions are differentiable. Be mindful of which function you use, though: indexing, for example, cannot give meaningful gradients for the indices (they would all be 0s anyway); only the tensor being indexed receives gradients.
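
As a tiny illustration (the values here are arbitrary), boolean indexing routes gradients only to the selected entries of the indexed tensor, while the unselected entry gets a zero gradient:

import torch

x = torch.tensor([0.5, 2.0, 3.0, 0.0001], requires_grad=True)
selected = x[x > 1e-3]       # differentiable w.r.t. x, not w.r.t. the mask
selected.mean().backward()
print(x.grad)                # tensor([0.3333, 0.3333, 0.3333, 0.0000])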