No speedup when doing the backward pass for only a part of the samples in a minibatch - why?

I have a function (see below) which modifies the loss so that only the loss of the K samples in the minibatch with the lowest loss is returned. The idea is to focus each optimization step on those samples.

So I first do a forward pass to get the loss value for each sample in the minibatch, and then adapt the loss via the function get_adapted_loss_for_minibatch.
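
For reference, the surrounding training step looks roughly like this (a minimal sketch; model, criterion, optimizer and train_loader are placeholder names, and the criterion is assumed to be created with reduction='none' so that it returns a per-sample loss vector):

import torch

criterion = torch.nn.CrossEntropyLoss(reduction='none')  # per-sample loss, not averaged

for inputs, targets in train_loader:
    optimizer.zero_grad()
    outputs = model(inputs)                        # forward pass for the full minibatch
    per_sample_loss = criterion(outputs, targets)  # one loss value per sample
    loss = get_adapted_loss_for_minibatch(per_sample_loss).mean()
    loss.backward()                                # backward pass
    optimizer.step()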

As the adapted loss takes into account only a fraction of the samples in the minibatch (currently 60% of the samples), I was expecting to also get a measurable speedup during training, since the backward step should only have to be done for that fraction of the samples in the minibatch.

But unfortunately this is not the case: training takes practically the same amount of time as when I use all samples in the minibatch (i.e. when I do not adapt the loss). I am using a 'densenet121' network, and training is done on CIFAR-100.

Am I doing something wrong? Should I disable autograd for some samples in the minibatch manually? I thought the 'topk' function would do that automatically.

import torch

def get_adapted_loss_for_minibatch(loss):
    # Returns the loss containing only the samples of the minibatch with the _lowest_ loss.
    # Parameter 'loss' must be a 1-D tensor containing the per-sample loss for all samples
    # in the (original) minibatch.
    minibatch_size = loss.size(0)
    r = 0.6 * minibatch_size
    # Round r to an integer, safeguard if r is 0
    r = max(round(r), 1)
    # 'topk' with largest=False returns the losses of the 'r' samples with the _lowest_ loss in the minibatch
    # See documentation at https://pytorch.org/docs/stable/generated/torch.topk.html
    # Note the 'topk' operation is differentiable, see https://stackoverflow.com/questions/67570529/derive-the-gradient-through-torch-topk
    # and https://math.stackexchange.com/questions/4146359/derivative-for-masked-matrix-hadamard-multiplication
    loss_adapted = torch.topk(loss, r, largest=False, sorted=False, dim=0)[0]
    return loss_adapted
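
To make the second question above concrete, what I mean by "disabling autograd for some samples manually" would be something like the following two-pass variant (just a sketch, using the same placeholder names as above; the first pass runs under torch.no_grad() only to pick the samples, and the forward pass is then repeated on the selected subset):

with torch.no_grad():
    # First pass: only used to pick the r samples with the lowest loss, no graph is built
    per_sample_loss = criterion(model(inputs), targets)
r = max(round(0.6 * inputs.size(0)), 1)
_, idx = torch.topk(per_sample_loss, r, largest=False, sorted=False, dim=0)
# Second pass: forward and backward only for the selected samples
outputs = model(inputs[idx])
loss = criterion(outputs, targets[idx]).mean()
loss.backward()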

This has been resolved. See the Stack Overflow question "Pytorch: no speedup when doing backward pass only for a part of the samples in minibatch - why?".