I would like to remove “easy” samples with negligible error from the loss, so that they do not dominate the mean loss value. Focal loss in the RetinaNet paper uses weighting to address the issue. I just want to remove those samples.
I suppose the following is not right, as it changes the loss value in place and disrupts the computation graph:
    output = network(images)
    loss = loss_function(output)
    loss = loss[loss > 1e-6]
    loss = torch.mean(loss)
    loss.backward()
I know there are some selection functions in PyTorch, but I am not sure which one is suitable in this case.
What is the easiest way to do this correctly? (In a differentiable way, without disrupting the computation graph)
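For reference, here is a minimal sketch of the masking approach. It assumes the loss is computed per sample (`reduction="none"`). The `nn.Linear` model, the random `images` and `targets`, and the `MSELoss` choice are hypothetical stand-ins, not part of the original setup. As far as I understand, boolean-mask indexing creates a new tensor rather than modifying `loss` in place, and autograd passes gradients only to the selected elements, so the graph should stay intact.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Hypothetical stand-ins for the network and data in the question.
network = nn.Linear(4, 1)
images = torch.randn(8, 4)
targets = torch.randn(8, 1)

# reduction="none" keeps one loss value per sample instead of a scalar.
loss_function = nn.MSELoss(reduction="none")

output = network(images)
per_sample = loss_function(output, targets).squeeze(1)  # shape: (8,)

# Boolean-mask indexing returns a new tensor (not an in-place op);
# gradients flow back only to the elements that were selected.
mask = per_sample.detach() > 1e-6
kept = per_sample[mask]

# Guard against an empty selection, where mean() would give NaN.
if kept.numel() > 0:
    loss = kept.mean()
else:
    loss = per_sample.sum() * 0.0  # zero loss that still has a grad_fn

loss.backward()
```

The empty-selection guard matters once most samples become "easy": the mean of an empty tensor is NaN, and backpropagating it would corrupt the gradients.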