OHEM: compute gradients only for a subset of a batch

How can I compute the gradient of a loss function using only a subset of the current batch?
I am using MSELoss, and I want to implement something like online hard example mining (OHEM): sort the per-sample losses computed in the forward pass over the whole batch, select the top 70% as hard samples, and then compute gradients only for those 70% hardest samples. Some works claim that this sometimes accelerates convergence and gives better results for object detection.

So there might be 2 possibilities:

  1. Compute the squared distance between my N outputs and N ground truths, sort, select the top 70%, put them into a batch prediction, compute loss = criterion(prediction, target), and call loss.backward(). This doesn't seem very efficient, since it requires performing forward propagation twice…

  2. Compute the loss for the whole batch, then compute the squared distances between prediction and target, sort them, and somehow call loss.backward() for specific indices only. (By the way, is there a PyTorch function that computes a vector of squared distances between N pairs of 2D tensors? MSELoss just returns one scalar, the average of those distances.) Does this seem reasonable, and if so, how would I implement it in PyTorch?
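
A minimal sketch of option 2, assuming dimension 0 is the batch dimension: F.mse_loss(..., reduction="none") returns element-wise squared errors instead of one scalar, and averaging over the non-batch dimensions gives a per-sample loss vector (use .sum(dim=1) instead of .mean(dim=1) if you want the summed squared distance; the ranking is the same when all samples share a shape). The names ohem_mse_loss and keep_ratio, and the toy nn.Linear model, are hypothetical and only for illustration.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


def ohem_mse_loss(prediction, target, keep_ratio=0.7):
    """Average the MSE over the hardest `keep_ratio` fraction of the batch.

    Hypothetical helper for this sketch; assumes dim 0 is the batch dimension.
    """
    # reduction="none" keeps one squared error per element: shape (N, ...)
    sq_err = F.mse_loss(prediction, target, reduction="none")
    # Average over all non-batch dims -> one loss per sample: shape (N,)
    per_sample = sq_err.reshape(sq_err.shape[0], -1).mean(dim=1)
    # Number of hard samples to keep (at least one)
    k = max(1, round(keep_ratio * per_sample.numel()))
    # topk returns the k largest losses without a full sort
    hard_losses, hard_idx = torch.topk(per_sample, k)
    # Only the selected entries reach the final loss, so only they get gradient
    return hard_losses.mean(), hard_idx


torch.manual_seed(0)
model = nn.Linear(4, 2)   # toy stand-in for the real network
x = torch.randn(10, 4)
y = torch.randn(10, 2)

prediction = model(x)     # one forward pass over the whole batch
loss, hard_idx = ohem_mse_loss(prediction, y, keep_ratio=0.7)
loss.backward()           # gradients flow only through the 7 hardest samples
```

Because the batch goes through the network once and only the top-k losses are averaged, this avoids the second forward pass from option 1; the easier samples simply receive zero gradient.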

Thanks!

For 2, you can implement the vector of squared distances using PyTorch mathematical ops, and then replace the elements of prediction with the corresponding elements of target for the indices you don't want to backpropagate. (They will still be backpropagated, but with zero gradient.)
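
A small sketch of that suggestion, assuming prediction has shape (N, D) and using a random leaf tensor in place of a real model output: the squared distances come from plain tensor ops, and torch.where swaps in the target for the easy samples so their error, and therefore their gradient, is exactly zero.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
N, D = 8, 3
prediction = torch.randn(N, D, requires_grad=True)  # stand-in for model output
target = torch.randn(N, D)

# Per-sample squared distances built from plain tensor ops: shape (N,)
sq_dist = ((prediction - target) ** 2).sum(dim=1)

# Choose the hardest 70% of samples (the selection itself needs no gradient)
k = max(1, round(0.7 * N))
hard_idx = torch.topk(sq_dist.detach(), k).indices
keep = torch.zeros(N, dtype=torch.bool)
keep[hard_idx] = True

# For the easy samples, swap the prediction for the target: their error is
# exactly zero, so they contribute nothing to the loss or the gradient.
masked_pred = torch.where(keep.unsqueeze(1), prediction, target)
loss = F.mse_loss(masked_pred, target)
loss.backward()

print(prediction.grad[~keep])  # all zeros
```

One thing to keep in mind: with the default reduction="mean", F.mse_loss still divides by all N * D elements rather than only the kept ones, so the loss and gradients are scaled down compared with averaging over the hard samples alone.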

I am working on the same thing. How do you do hard negative mining now?

I think this is the function you want: Top K gradient for Cross Entropy (OHEM)
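
The linked thread's code is not reproduced here. As a hedged illustration of the same top-K idea for classification, topk_cross_entropy below is a hypothetical helper (not the function from that thread) that keeps only the k largest per-sample cross-entropy losses; it assumes logits of shape (N, C) and integer class labels.

```python
import torch
import torch.nn.functional as F


def topk_cross_entropy(logits, labels, keep_ratio=0.7):
    """Cross-entropy averaged over the k hardest samples (hypothetical helper)."""
    # One loss value per sample: shape (N,)
    per_sample = F.cross_entropy(logits, labels, reduction="none")
    k = max(1, round(keep_ratio * per_sample.numel()))
    # Keep the k largest losses; the easier samples get no gradient
    hard_losses, _ = torch.topk(per_sample, k)
    return hard_losses.mean()


torch.manual_seed(0)
logits = torch.randn(10, 5, requires_grad=True)  # stand-in for classifier output
labels = torch.randint(0, 5, (10,))
loss = topk_cross_entropy(logits, labels)
loss.backward()
```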