Delete samples from loss in a differentiable way

I would like to remove “easy” samples with negligible error from the loss, so that they do not dominate the mean loss value. Focal loss in the RetinaNet paper uses weighting to address the issue. I just want to remove those samples.

I suppose the following is not right, as it changes the loss value in place and disrupts the computation graph.

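```python
output = network(images)
# assumes loss_function returns one loss per sample (e.g. reduction='none')
loss = loss_function(output)
loss = loss[loss > 1e-6]  # keep only the "hard" samples
loss = torch.mean(loss)
loss.backward()
```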

I know there are some selection functions in PyTorch, but I'm not sure which one is suitable in this case.
What is the easiest way to do this correctly (in a differentiable way, without disrupting the computation graph)?

Thanks.

Hi,

Actually, your example does not do any in-place operation! It would if you were assigning into the selection, e.g. loss[loss > 1e-6] = 0.
So your code should work without issue, and the entries not selected by your mask will just contribute gradients of 0 to the rest of the network.
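A small sketch to check this, using hypothetical per-sample losses as a stand-in for the output of loss_function (assumed to be computed with reduction='none'). Boolean-mask indexing returns a new tensor, and backward() sends a gradient of 0 to every dropped entry:

```python
import torch

# Hypothetical per-sample losses; requires_grad=True makes them a graph leaf
# so we can inspect the gradient each sample receives.
loss = torch.tensor([1e-8, 0.5, 2.0, 1e-7], requires_grad=True)

# Boolean-mask indexing builds a new tensor (a copy); `loss` itself is untouched.
kept = loss[loss > 1e-6]
mean_loss = kept.mean()
mean_loss.backward()

print(kept)       # the two "hard" samples: 0.5 and 2.0
print(loss.grad)  # tensor([0.0000, 0.5000, 0.5000, 0.0000])
```

The dropped samples get exactly 0, and each kept sample gets 1/2 because the mean is taken over the two kept entries only.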

Thanks a lot!

So this expression does not break the computation graph, and gradients are properly backpropagated as intended (zero for the low-loss samples, normal for the rest)?

I thought I needed to use one of those selection functions to keep it differentiable.
I wonder how that is achieved in cases like this.

And in general, which ops/functions are differentiable and which are not?
Do you know of some doc or blog post that explains these in some detail?

Advanced indexing like this is actually implemented with a mix of these selection functions.
(Almost) all functions are differentiable, but be mindful of what they are differentiable with respect to. Indexing, for example, cannot give gradients for the indices (well, they would all be 0s), only for the tensor being indexed.
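To illustrate the last point, here is a minimal sketch with made-up values: the indexed tensor receives gradients (repeated picks accumulate), while the integer index tensor is not tracked by autograd at all. torch.index_select(x, 0, idx) would behave the same way here.

```python
import torch

x = torch.tensor([10.0, 20.0, 30.0], requires_grad=True)
# Integer indices: element 0 is picked twice, element 2 once, element 1 never.
idx = torch.tensor([0, 0, 2])

y = x[idx]          # advanced indexing: differentiable with respect to x
y.sum().backward()

print(x.grad)             # tensor([2., 0., 1.]): repeated picks accumulate
print(idx.requires_grad)  # False: integer indices carry no gradient
```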

I have a similar situation. I’m trying to keep only the samples for which the model’s confidence is above a certain threshold. Should the following work?

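```python
output = model(images)
# assumes loss_function returns one loss per sample (e.g. reduction='none')
loss = loss_function(output, targets)

output = output / softmax_temp
output_probs = torch.softmax(output, dim=-1)
largest_prob, _ = torch.max(output_probs, dim=-1)
# keep samples whose top-class probability exceeds thr
mask = torch.gt(largest_prob, thr)
loss = loss[mask]
loss = torch.mean(loss)
loss.backward()
```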

Yes, the code will run fine.
No gradients will flow back through the mask, since the comparison that creates it is not differentiable.
Gradients will still flow back properly through all the loss entries that were kept.
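A runnable sketch of this confidence-threshold masking. The model, data, softmax_temp and thr below are hypothetical stand-ins, and F.cross_entropy with reduction='none' is assumed as the per-sample loss_function:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)

# Hypothetical stand-ins for the model and data in the post above.
model = nn.Linear(4, 3)
images = torch.randn(8, 4)
targets = torch.randint(0, 3, (8,))
softmax_temp, thr = 1.0, 0.4

output = model(images)
# reduction="none" keeps one loss value per sample so it can be masked.
loss = F.cross_entropy(output, targets, reduction="none")

# Confidence of the top class for each sample.
output_probs = torch.softmax(output / softmax_temp, dim=-1)
largest_prob, _ = torch.max(output_probs, dim=-1)
mask = largest_prob > thr   # bool tensor: autograd never tracks it

kept = loss[mask]
# Guard: the mean of an empty selection would be NaN.
if kept.numel() > 0:
    kept.mean().backward()  # gradients reach model only via the kept losses

print(mask.requires_grad)   # False
```

The empty-selection guard is a design choice for this sketch; in a training loop one might instead skip the optimizer step for that batch.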


Thanks! You actually answered the question I should have asked. I wanted to make sure no gradients flow back along the mask. I take it that if I calculated the mask in a differentiable way, I would have to add a detach or no_grad to ensure that no gradients flow back along the mask.

Well, given that the mask is boolean, it is not really possible to compute it in a differentiable way.
Gradients for discrete types are not really a thing :smiley:
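A quick way to see this, with made-up probabilities: a comparison always yields a torch.bool tensor, and autograd refuses to track non-floating-point tensors altogether.

```python
import torch

probs = torch.tensor([0.2, 0.7, 0.9], requires_grad=True)
mask = probs > 0.5   # comparisons produce a torch.bool tensor

print(mask.dtype, mask.requires_grad)   # torch.bool False

# Autograd only tracks floating-point (and complex) tensors,
# so asking it to track a bool tensor raises an error.
raised = False
try:
    mask.requires_grad_()
except RuntimeError as err:
    raised = True
    print("RuntimeError:", err)
```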

:smiley:

Yes, of course. I meant that instead of using a boolean mask, one could use a multiplicative mask:

```python
loss = loss * mask
```

where mask could be in the range 0…1.
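A minimal sketch of such a soft multiplicative mask, assuming (hypothetically) that it comes from a sigmoid over some confidence logits. Because the sigmoid is differentiable, .detach() is what keeps gradients from flowing back into the logits through the mask; computing the mask under torch.no_grad() would have the same effect.

```python
import torch

torch.manual_seed(0)
logits = torch.randn(5, requires_grad=True)  # hypothetical confidence scores
loss = torch.rand(5, requires_grad=True)     # hypothetical per-sample losses

# Soft mask in [0, 1]; detach() cuts it out of the autograd graph.
soft_mask = torch.sigmoid(logits).detach()

weighted = (loss * soft_mask).mean()
weighted.backward()

print(logits.grad)                               # None: nothing reached logits
print(torch.allclose(loss.grad, soft_mask / 5))  # True: d/d loss_i = mask_i / N
```

Without the detach() call, logits.grad would be populated, i.e. the mask itself would be trained.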


Hey all :slight_smile:

@albanD I do have a follow up question.
It's very much related to what was already asked, but I'm struggling to find the ideal solution.
I have an RL agent running through its environment collecting samples.
Some of the samples come from its own predictions. However, the PyTorch model has two separate output heads, and for one of those heads I want to remove some samples from the loss, so that these particular samples within the batch of 64 do not influence any of the gradients. What I've done so far is to set the corresponding entries to zero and then call .sum() or .mean() to get the scalar loss value. I couldn't find anything about setting entries in the loss vector to zero and how that affects the backprop graph.
So I have a mask, and the entries it selects should not influence the learning process at all. How should I proceed?
Thanks in advance

Indeed, if you overwrite the value, no gradient will flow back to the original value, similar to masking.

So my way of setting them to zero with the mask is correct. I mean, sure, the gradient times zero is again zero… :monkey_face:
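A sketch of the two-head case, under stated assumptions: TwoHeadNet, the cross-entropy/MSE losses, and the random keep mask are all hypothetical stand-ins for the agent's model and data. The masked head's per-sample loss is multiplied by the mask and averaged over the kept samples only, so the loss scale does not depend on how many samples were masked. The checks confirm that masked samples get zero gradient and that the result matches boolean indexing.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)

class TwoHeadNet(nn.Module):
    """Hypothetical two-head network standing in for the RL agent's model."""
    def __init__(self, n_in=8, n_actions=4):
        super().__init__()
        self.body = nn.Linear(n_in, 16)
        self.policy_head = nn.Linear(16, n_actions)
        self.value_head = nn.Linear(16, 1)

    def forward(self, x):
        h = torch.relu(self.body(x))
        return self.policy_head(h), self.value_head(h).squeeze(-1)

net = TwoHeadNet()
obs = torch.randn(64, 8)
actions = torch.randint(0, 4, (64,))
returns = torch.randn(64)
keep = torch.rand(64) > 0.3   # hypothetical mask: True = use for value head

logits, values = net(obs)
policy_loss = F.cross_entropy(logits, actions)                    # all samples
value_per_sample = F.mse_loss(values, returns, reduction="none")  # shape (64,)

# Zero out masked samples, then divide by the number of kept samples
# (clamped to avoid division by zero if everything is masked).
m = keep.float()
value_loss = (value_per_sample * m).sum() / m.sum().clamp(min=1.0)

# Check 1: masked samples receive exactly zero gradient.
grad_values = torch.autograd.grad(value_loss, values, retain_graph=True)[0]
print(bool((grad_values[~keep] == 0).all()))  # True

# Check 2: same value-head gradient as boolean indexing.
g_mul = torch.autograd.grad(value_loss, net.value_head.weight, retain_graph=True)[0]
g_idx = torch.autograd.grad(value_per_sample[keep].mean(), net.value_head.weight,
                            retain_graph=True)[0]
print(torch.allclose(g_mul, g_idx))           # True

(policy_loss + value_loss).backward()
```

One caveat of multiplying by zero: if a masked-out entry is inf or nan, 0 * inf is nan and poisons the sum, whereas boolean indexing drops such entries from the forward pass entirely.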
