Efficiently backpropagate through a subset of a minibatch

Hi, I have a batch size of 2048 and, based on the results of the forward pass, I only want to backpropagate through half of the minibatch (I have 1024 indices). If I do this naively, it is just as slow as backpropagating through the whole minibatch. Is there a way to do this efficiently in PyTorch?
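
To make the question concrete, here is a minimal sketch of the setup. Everything in it is a placeholder assumption rather than the real code: the small `model`, the random `x` and `y`, and the rule of keeping the 1024 samples with the highest loss. `naive_step` indexes the per-sample loss after a full forward pass, so autograd still walks the graph for all 2048 rows. `two_pass_step` is one possible alternative, not a confirmed fix: it selects under `torch.no_grad()` and then repeats the forward pass on only the selected rows.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Hypothetical stand-ins for the real model and data (assumptions).
model = nn.Sequential(nn.Linear(32, 64), nn.ReLU(), nn.Linear(64, 1))
x = torch.randn(2048, 32)
y = torch.randn(2048, 1)
loss_fn = nn.MSELoss(reduction="none")  # keep one loss value per sample


def naive_step():
    """Full forward with grad, then backprop only the selected losses."""
    model.zero_grad()
    per_sample = loss_fn(model(x), y).squeeze(1)  # shape (2048,)
    # Assumed selection rule: the 1024 samples with the largest loss.
    idx = per_sample.detach().topk(1024).indices
    # The graph still covers all 2048 rows; unselected rows get zero gradient.
    per_sample[idx].mean().backward()
    return idx


def two_pass_step():
    """Select without building a graph, then forward/backward on the subset."""
    model.zero_grad()
    with torch.no_grad():
        per_sample = loss_fn(model(x), y).squeeze(1)
    idx = per_sample.topk(1024).indices
    # Second forward pass sees only 1024 rows, so backward does too.
    loss_fn(model(x[idx]), y[idx]).mean().backward()
    return idx
```

Both functions should produce the same parameter gradients for a deterministic model like this one. Whether the second variant is actually faster depends on how the extra forward pass compares to the saved backward work, and it would not be equivalent for layers whose output depends on the whole batch (for example batch normalization in training mode).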
Best, Jannis