I have a function (see below) which modifies the loss function so that it returns only the loss for the K samples in the minibatch with the lowest loss. The idea is to focus each optimization step on these samples.
So I first do a forward pass to get the loss value for each sample in the mini-batch, then adapt the loss via the function `get_adapted_loss_for_minibatch`.
As the adapted loss takes into account only a certain fraction of the samples in the minibatch (currently 60% of the samples), I was expecting a measurable speedup during training, since the backward step would only have to be done for a fraction of the samples in the minibatch.
But unfortunately this is not the case: training takes practically the same amount of time as when I use all samples in the minibatch (i.e., when I do not adapt the loss). I am using a ‘densenet121’ network, and training is done on CIFAR-100.
Am I doing something wrong? Should I disable autograd for some samples in the minibatch manually? I thought the ‘topk’ function would do that automatically.
import torch

def get_adapted_loss_for_minibatch(loss):
    # Returns the loss containing only the samples of the mini-batch with the _lowest_ loss.
    # Parameter 'loss' must be a 1-D tensor containing the per-sample loss for all samples
    # in the (original) minibatch
    minibatch_size = loss.size(0)
    r = 0.6 * minibatch_size
    # Round r to an integer; safeguard so that r is at least 1
    r = max(round(r), 1)
    # 'topk' with largest=False returns the loss for the 'r' samples with the
    # _lowest_ loss in the minibatch
    # See documentation at https://pytorch.org/docs/stable/generated/torch.topk.html
    # Note the 'topk' operation is differentiable, see https://stackoverflow.com/questions/67570529/derive-the-gradient-through-torch-topk
    # and https://math.stackexchange.com/questions/4146359/derivative-for-masked-matrix-hadamard-multiplication
    loss_adapted = torch.topk(loss, r, largest=False, sorted=False, dim=0)[0]
    return loss_adapted
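For context, here is how I am wiring the function into a training step. This is a minimal, self-contained sketch (the tiny linear model and random batch here are just placeholders, not my actual densenet121/CIFAR-100 setup); the key point is that the criterion uses `reduction='none'` so that a per-sample loss vector reaches `topk`:

```python
import torch
import torch.nn as nn

# Placeholder model and batch, just to show the wiring
model = nn.Linear(32, 100)
criterion = nn.CrossEntropyLoss(reduction='none')  # per-sample losses, not the mean

inputs = torch.randn(8, 32)
targets = torch.randint(0, 100, (8,))

per_sample_loss = criterion(model(inputs), targets)  # shape: (8,)

# Keep the 60% of samples with the lowest loss; here k = max(round(0.6 * 8), 1) = 5
k = max(round(0.6 * per_sample_loss.size(0)), 1)
selected = torch.topk(per_sample_loss, k, largest=False, sorted=False, dim=0)[0]

# Reduce to a scalar and backpropagate; gradients flow only through
# the k selected entries of the loss vector
loss = selected.mean()
loss.backward()
```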