How can I compute the gradient of a loss function using only a subset of the current batch?
I am using `MSELoss`, and I want to implement something like online hard negative mining: sort the per-sample losses computed in the forward pass over all samples in a batch, select the top 70% of them as hard samples, and then compute gradients only for those 70% hardest samples. Some works claim this sometimes accelerates convergence and gives better results for object detection.
So I see two possibilities:
1. Compute the squared distances between my N outputs and the N ground truths, sort them, select the top 70%, assemble those samples into a new batch, then run `loss = criterion(prediction, target)` and `loss.backward()`. This doesn't seem very efficient, since it requires performing forward propagation twice.
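Roughly what I mean by the first option, as a sketch (the model `net`, the sizes, and the 70% ratio are just placeholders for illustration):

```python
import torch
import torch.nn as nn

net = nn.Linear(4, 4)                      # stand-in model, for illustration only
criterion = nn.MSELoss()

inputs = torch.randn(10, 4)
targets = torch.randn(10, 4)

# First forward pass, no grad needed: score each sample by its squared distance
with torch.no_grad():
    preds = net(inputs)
    per_sample = ((preds - targets) ** 2).mean(dim=1)   # one value per sample

k = int(0.7 * inputs.size(0))              # keep the 70% hardest samples
hard_idx = per_sample.topk(k).indices

# Second forward pass, grad enabled, on the hard subset only
loss = criterion(net(inputs[hard_idx]), targets[hard_idx])
loss.backward()
```

So the network is run once to rank the samples and a second time to actually get gradients, which is the inefficiency I mentioned.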
2. Compute the `loss` for the whole batch, then compute the squared distances between `prediction` and `target` (by the way, is there a PyTorch function that would compute a vector of squared distances between N pairs of 2D tensors? `MSELoss` just returns one scalar, which is an averaged sum of such distances), sort them, and then somehow call `loss.backward()` only for the selected indices.

Does this seem reasonable, and if yes, how would I implement it in PyTorch?
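For the second option, I believe `MSELoss(reduction='none')` gives the per-element vector I was asking about, so something like this sketch (again with a stand-in model and placeholder sizes):

```python
import torch
import torch.nn as nn

net = nn.Linear(4, 4)                      # stand-in model, for illustration only
criterion = nn.MSELoss(reduction='none')   # keep one loss value per element

inputs = torch.randn(10, 4)
targets = torch.randn(10, 4)

preds = net(inputs)                        # single forward pass, grad enabled
per_sample = criterion(preds, targets).mean(dim=1)   # one loss per sample

k = int(0.7 * inputs.size(0))              # keep the 70% hardest samples
hard_losses, _ = per_sample.topk(k)

# Backward only through the selected losses; the discarded samples
# never enter this scalar, so they contribute no gradient
hard_losses.mean().backward()
```

If this is right, it avoids the second forward pass entirely, since backpropagating through the mean of the selected entries is the same as zeroing the gradient contribution of the easy samples.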