Delete samples from loss in a differentiable way

I would like to remove “easy” samples with negligible error from the loss, so that they do not dominate the mean loss value. Focal loss, from the RetinaNet paper, addresses this issue with weighting; I just want to remove those samples entirely.

I suppose the following is not right, as it changes the loss value in place and disrupts the computation graph.

output = network(images)
loss = loss_function(output)   # per-sample loss, e.g. computed with reduction='none'
loss = loss[loss > 1e-6]       # keep only the samples with non-negligible loss
loss = torch.mean(loss)
loss.backward()

I know there are some selection functions in PyTorch, but I'm not sure which one is suitable in this case.
What is the easiest way to do this correctly (in a differentiable way, without disrupting the computation graph)?

Thanks.

Hi,

Actually, your example does not do any in-place operation! It would if you were doing loss[loss > 1e-6].zero_().
So your code should work without issue, and the entries that were not selected by your mask will simply contribute a gradient of 0 to the rest of the network.
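If you want to verify this yourself, here is a minimal, self-contained sketch (the toy network, targets, and threshold below are placeholders I made up, not from the original post). It computes a per-sample loss with reduction='none', drops the near-zero entries with boolean indexing, and backpropagates through the kept samples only:

import torch
import torch.nn as nn

torch.manual_seed(0)

# Toy setup: a tiny linear "network" and a per-sample loss (reduction='none'),
# so that `per_sample_loss` is a vector with one entry per sample.
network = nn.Linear(4, 1)
images = torch.randn(8, 4)
targets = torch.randn(8, 1)
loss_function = nn.MSELoss(reduction='none')

output = network(images)
per_sample_loss = loss_function(output, targets).squeeze(1)  # shape: (8,)

# Boolean-mask indexing is differentiable w.r.t. the selected entries;
# it acts like a gather, so gradients flow back through the kept samples only.
mask = per_sample_loss > 1e-6
if mask.any():                          # guard against an empty selection
    loss = per_sample_loss[mask].mean()
else:
    loss = per_sample_loss.sum() * 0.0  # keeps the graph alive with zero loss

loss.backward()
print(network.weight.grad)              # gradient comes only from the kept samples

The if branch just guards against the corner case where every sample falls below the threshold, in which case the mean of an empty tensor would be NaN.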

Thanks a lot!

So this expression does not break the computation graph, and gradients are properly backpropagated as intended (zero for the low-loss samples, normal for the rest)?

I thought I needed to use one of those selection functions for it to be differentiable.
I wonder how that is achieved in such cases.

And in general, which ops/functions are differentiable and which are not? Do you know of a doc or blog post that explains this in some detail?