Selecting topk loss values using one pass through data and reduction='none'


I’m trying to optimize using only a subset of the per-sample losses in a batch. I run the model as usual on a big batch with no reduction on the loss terms. Then, based on some criterion, I choose some of the individual loss terms, take their mean, and call backward on that loss. Is doing this OK?

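A sample:

# per-sample losses, shape (batch_size,), computed with reduction='none'
loss_all = run_model(batch_all)
# keep only the chosen loss terms; topk_idxs holds their batch indices
loss_sub = torch.index_select(loss_all, 0, topk_idxs)
loss = loss_sub.mean()
loss.backward()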

Is this OK? Or, for such a use case, do I have to run with no_grad first, select the data, and then run the model on the selected data?


OK, I tried a toy example, and this gives the same result as selecting the inputs from the beginning. However, the computational cost of backward doesn’t decrease. Is there a way to exclude unwanted samples of the batch from backward based on the loss after doing a forward with all the samples (to speed up the backward pass)?
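
A toy check along these lines might look like the sketch below. The model, data shapes, loss function, and k are all made up for illustration (none of them come from the actual setup), and the model deliberately has no batch norm or other cross-sample layers. It compares parameter gradients from "full forward, then select top-k losses" against "forward and backward on the selected samples only".

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Hypothetical toy model with no cross-sample interaction (no batch norm).
model = nn.Sequential(nn.Linear(8, 16), nn.ReLU(), nn.Linear(16, 1))
x = torch.randn(32, 8)
y = torch.randn(32, 1)
loss_fn = nn.MSELoss(reduction="none")
k = 8

# Variant 1: full-batch forward, then select the top-k per-sample losses.
loss_all = loss_fn(model(x), y).squeeze(1)  # shape (32,)
topk_idxs = torch.topk(loss_all.detach(), k).indices
loss = loss_all.index_select(0, topk_idxs).mean()
model.zero_grad()
loss.backward()
grads_select = [p.grad.clone() for p in model.parameters()]

# Variant 2: forward and backward on the selected samples only.
model.zero_grad()
loss_sub = loss_fn(model(x[topk_idxs]), y[topk_idxs]).mean()
loss_sub.backward()
grads_subset = [p.grad.clone() for p in model.parameters()]

# The two sets of gradients should agree up to floating-point noise.
grads_match = all(
    torch.allclose(a, b, atol=1e-5) for a, b in zip(grads_select, grads_subset)
)
print("gradients match:", grads_match)
```

Note that variant 1 still builds and traverses the graph for all 32 samples, which is why its backward cost does not drop.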

Hi Sina!

No, there isn’t. The problem is that although in your use case, the batch
elements are independent of one another – that is, their gradients don’t
mix together – backward() doesn’t and can’t know this.

Suppose you had multiple channels that were mixed together in the forward
pass and then you only wanted to compute the gradients of the topk largest
channel losses. backward() would have to process all of the channels to
get the correct gradients – because they mix together. But backward()
can’t tell the independent-batch-elements use case from the
mixed-together-channels use case, so it (thinks it) has to process all of
your batch elements even though you only want gradients for the batch
elements with the topk largest losses.
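
To make the distinction concrete, here is a small sketch with made-up tensors (x, w, and k are purely illustrative). In the independent case, the gradient rows for the unselected batch elements come out as zeros, but autograd still computes them at full size. In the mixed case (here rows are mixed by subtracting the batch mean, standing in for any operation that mixes elements together), the unselected rows receive genuinely nonzero gradients.

```python
import torch

torch.manual_seed(0)
x = torch.randn(4, 3, requires_grad=True)
w = torch.randn(3)
k = 2

def unselected_input_grad(per_row_loss):
    # Backward through the mean of the top-k losses, then return the
    # gradient rows of x that belong to the rows NOT selected.
    idx = torch.topk(per_row_loss.detach(), k).indices
    (grad,) = torch.autograd.grad(per_row_loss[idx].mean(), x)
    keep = torch.ones(x.shape[0], dtype=torch.bool)
    keep[idx] = False
    return grad[keep]

# Independent case: each row's loss depends only on that row of x.
loss_independent = (x * w).sum(dim=1) ** 2

# Mixed case: every row is first shifted by the mean over all rows.
x_centered = x - x.mean(dim=0, keepdim=True)
loss_mixed = (x_centered * w).sum(dim=1) ** 2

g_independent = unselected_input_grad(loss_independent)
g_mixed = unselected_input_grad(loss_mixed)
print("independent, max |grad| on unselected rows:", g_independent.abs().max().item())
print("mixed, max |grad| on unselected rows:", g_mixed.abs().max().item())
```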

The approach you mention (running under no_grad first, selecting the
data, and then running the model on just the selected data) has the
potential to make your backward pass cheaper (but at the cost of
performing the additional forward pass). Also, you would then be
performing the full-batch forward pass without tracking gradients, so
this could also offer some cost savings.
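
A minimal sketch of that two-pass scheme, assuming a made-up model, optimizer, loss function, and k (all hypothetical stand-ins for the real training setup):

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Hypothetical stand-ins for the real model, optimizer, and data.
model = nn.Sequential(nn.Linear(8, 16), nn.ReLU(), nn.Linear(16, 1))
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
loss_fn = nn.MSELoss(reduction="none")
x, y = torch.randn(32, 8), torch.randn(32, 1)
k = 8

# Pass 1: full batch under no_grad, so no graph is built and no
# activations are kept for backward.
with torch.no_grad():
    loss_all = loss_fn(model(x), y).squeeze(1)  # shape (32,)
topk_idxs = torch.topk(loss_all, k).indices

# Pass 2: only the selected samples, with gradient tracking.
optimizer.zero_grad()
loss = loss_fn(model(x[topk_idxs]), y[topk_idxs]).mean()
loss.backward()
optimizer.step()
```

If the model contains layers that behave differently in training mode (dropout, batch norm), the losses from pass 1 and pass 2 will not correspond exactly, so the selection is only approximate in that case.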

However, depending on the structure of your model and how your gpu (or
cpu) is being used, running your forward / backward pass on just your
topk batch may or may not be significantly cheaper than running it on
your full batch, so you’d have to time the two versions to be sure that such
a scheme actually helps.
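
One rough way to time the two versions is sketched below. The model size, batch size, k, and repeat count are arbitrary placeholders, and the numbers will depend heavily on the hardware and the real model, so treat this only as a template for the measurement.

```python
import time
import torch
import torch.nn as nn

torch.manual_seed(0)

# Hypothetical model and data; substitute the real ones when measuring.
model = nn.Sequential(nn.Linear(64, 256), nn.ReLU(), nn.Linear(256, 1))
loss_fn = nn.MSELoss(reduction="none")
x, y = torch.randn(512, 64), torch.randn(512, 1)
k = 64

def single_pass():
    # Full forward with grad tracking, select top-k losses, full backward.
    model.zero_grad()
    loss_all = loss_fn(model(x), y).squeeze(1)
    idx = torch.topk(loss_all.detach(), k).indices
    loss_all[idx].mean().backward()

def two_pass():
    # Full forward under no_grad, then forward/backward on the subset.
    model.zero_grad()
    with torch.no_grad():
        loss_all = loss_fn(model(x), y).squeeze(1)
    idx = torch.topk(loss_all, k).indices
    loss_fn(model(x[idx]), y[idx]).mean().backward()

def time_it(fn, repeats=20):
    fn()  # warm-up run, not timed
    if torch.cuda.is_available():
        torch.cuda.synchronize()  # only matters if the work runs on a GPU
    start = time.perf_counter()
    for _ in range(repeats):
        fn()
    if torch.cuda.is_available():
        torch.cuda.synchronize()
    return (time.perf_counter() - start) / repeats

t_single = time_it(single_pass)
t_two = time_it(two_pass)
print(f"single pass: {t_single * 1e3:.3f} ms, two pass: {t_two * 1e3:.3f} ms")
```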


K. Frank

Thanks KFrank for your comment.

Actually, I think the batch elements are independent in my case (though I’m not sure whether batch norm breaks that independence).
I don’t want to do this on other dimensions, such as channels, that could mix.
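
On the batch norm question, a quick check like the following sketch can show whether a layer couples batch elements. The shapes and the choice of "selected" indices are made up. In training mode, BatchNorm1d normalizes with statistics computed over the whole batch, so the loss of the selected samples depends on the unselected inputs as well; LayerNorm, which normalizes within each sample, is included for contrast.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
x = torch.randn(6, 4, requires_grad=True)
selected = torch.tensor([0, 1])         # hypothetical "top-k" samples
unselected = torch.tensor([2, 3, 4, 5])

def grad_wrt_input(layer):
    # Loss uses only the selected samples; return d(loss)/d(x) for all rows.
    out = layer(x)
    loss = out.pow(2).sum(dim=1)[selected].mean()
    (g,) = torch.autograd.grad(loss, x)
    return g

bn = nn.BatchNorm1d(4).train()  # batch statistics: samples are coupled
ln = nn.LayerNorm(4)            # per-sample statistics: no coupling

g_bn = grad_wrt_input(bn)
g_ln = grad_wrt_input(ln)
print("BatchNorm, max |grad| on unselected rows:", g_bn[unselected].abs().max().item())
print("LayerNorm, max |grad| on unselected rows:", g_ln[unselected].abs().max().item())
```

If the batch norm value is nonzero, dropping the unselected samples from backward would change the resulting gradients for a model that uses batch norm in training mode.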

Currently, I’m trying to solve this by customizing the pack_hook and unpack_hook of the saved-tensors hooks (torch.autograd.graph.saved_tensors_hooks). Inside unpack_hook I want to select only the desired batch indices. I’m not sure this is the right way, but it was the closest approach I found. I would be glad to hear anyone’s thoughts on this.
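
For reference, here is a minimal sketch of how the saved-tensors hooks are attached, using a made-up model and data. It only records the shapes of the tensors autograd saves and returns them unchanged; it does not implement the batch-index selection. As far as I can tell, returning a tensor with a smaller batch dimension from unpack_hook would not be enough on its own, because the gradients flowing in from the loss still have the full batch size.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Hypothetical toy model and input, for illustration only.
model = nn.Sequential(nn.Linear(8, 16), nn.ReLU(), nn.Linear(16, 1))
x = torch.randn(32, 8)
saved_shapes = []

def pack_hook(t):
    # Called when autograd saves a tensor for backward during the forward.
    saved_shapes.append(tuple(t.shape))
    return t

def unpack_hook(t):
    # Called when backward needs the saved tensor; returned unchanged here.
    return t

with torch.autograd.graph.saved_tensors_hooks(pack_hook, unpack_hook):
    loss_all = model(x).pow(2).squeeze(1)  # per-sample losses, shape (32,)

idx = torch.topk(loss_all.detach(), 8).indices
loss_all[idx].mean().backward()

# Saved tensors whose leading dimension is the full batch size.
batch_sized = [s for s in saved_shapes if s and s[0] == x.shape[0]]
print(batch_sized)
```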

Also, I don’t want to use torch.no_grad, since the selected subset is a large fraction of the total batch, so the extra forward pass would end up taking more time than just running it without torch.no_grad.