How can I compute the gradient of a loss function using only a subset of the current batch?
I am using `MSELoss`, and I want to implement something like online hard negative sample mining: I would sort the per-sample losses computed in the forward pass over the batch, select the top 70% of them as hard samples, and then compute the gradients only for those 70% hardest samples. Some works claim that it sometimes accelerates convergence and provides better results for object detection.
So there might be 2 possibilities:
- Compute the squared distance between my N outputs and N ground truths. Sort, then select the top 70%, put them in a batch `prediction`, compute `loss = criterion(prediction, target)`, and do `loss.backward()`. This doesn’t seem very efficient since it requires performing forward propagation two times…
- Compute `loss` for the whole batch, then compute the squared distances between `prediction` and `target` (by the way, is there a PyTorch function that would compute a vector of squared distances between N pairs of 2D tensors? `MSELoss` just returns one scalar which is an averaged sum of such distances) and sort them. And then somehow use a `loss.backward()` but only for specific indexes.

Does this seem reasonable, and if yes, how to implement it in PyTorch?
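To make the second option concrete, here is a minimal sketch of what it might look like. It assumes that `reduction="none"` on the MSE loss returns one squared error per element, and that `torch.topk` can be used to pick the hardest samples. The names `model`, `inputs`, `hard_example_mse` and `keep_ratio` are hypothetical placeholders, not part of any existing API.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)

# Hypothetical stand-ins for the real network and data.
model = nn.Linear(8, 2)
inputs = torch.randn(10, 8)
target = torch.randn(10, 2)


def hard_example_mse(prediction, target, keep_ratio=0.7):
    """Mean squared error over the hardest `keep_ratio` fraction of the batch."""
    # reduction="none" keeps one squared error per element instead of a scalar.
    per_element = F.mse_loss(prediction, target, reduction="none")
    # Average over all non-batch dimensions: one loss value per sample.
    per_sample = per_element.flatten(start_dim=1).mean(dim=1)
    # Number of samples to keep (at least one).
    k = max(1, round(keep_ratio * per_sample.numel()))
    # topk returns the k largest per-sample losses and their batch indices.
    hard_losses, hard_idx = torch.topk(per_sample, k)
    return hard_losses.mean(), hard_idx


prediction = model(inputs)  # single forward pass over the whole batch
loss, hard_idx = hard_example_mse(prediction, target)

model.zero_grad()
loss.backward()  # only the selected samples contribute to the gradients
```

With this approach the forward pass runs once, and the samples that were not selected receive a zero gradient because they never enter the final `loss`.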
Thanks!