Select samples with same labels within a batch

Hi

For each sample within a batch of MNIST data, I need to select one or more samples with the same label.
Is there any way to do this without looping over the batch with a for loop? I mean a vectorized version
that can use GPU acceleration.

You need to pass the class (label) through the DataLoader; then you can simply use indexing. There is no magic way of knowing the class otherwise.
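A minimal sketch of the indexing idea. It assumes the DataLoader yields `(images, labels)` batches, as torchvision's MNIST dataset does, with `labels` an integer tensor of shape `(B,)`. The helper names `same_label_mask` and `sample_same_label_partner` are hypothetical. Comparing the label vector against itself through broadcasting gives a `(B, B)` boolean mask whose row `i` marks every sample sharing sample `i`'s label (the "one or more" case). Picking one random partner per row is then an `argmax` over masked random scores, with no Python loop, so it also runs on the GPU.

```python
import torch


def same_label_mask(labels: torch.Tensor, exclude_self: bool = True) -> torch.Tensor:
    """Return a (B, B) bool mask where mask[i, j] is True if labels[i] == labels[j]."""
    # Broadcasting (1, B) against (B, 1) compares every pair of labels at once.
    mask = labels.unsqueeze(0) == labels.unsqueeze(1)
    if exclude_self:
        # Drop the diagonal so a sample is not matched with itself.
        eye = torch.eye(labels.shape[0], dtype=torch.bool, device=labels.device)
        mask = mask & ~eye
    return mask


def sample_same_label_partner(labels: torch.Tensor) -> torch.Tensor:
    """For each sample, return the index of one random other sample with the same label.

    If a sample's label is unique in the batch, its own index is returned.
    """
    mask = same_label_mask(labels, exclude_self=True)
    has_partner = mask.any(dim=1)
    # Random scores in [0, 1) for valid partners, -1 for everything else,
    # so argmax picks a uniformly random valid partner in each row.
    scores = torch.rand(mask.shape, device=labels.device).masked_fill(~mask, -1.0)
    partner = scores.argmax(dim=1)
    self_idx = torch.arange(labels.shape[0], device=labels.device)
    return torch.where(has_partner, partner, self_idx)


if __name__ == "__main__":
    labels = torch.tensor([3, 1, 3, 7, 1, 3])
    images = torch.randn(6, 1, 28, 28)  # stand-in for an MNIST batch
    partner = sample_same_label_partner(labels)
    partner_images = images[partner]  # gather the matching samples by indexing
    print(partner, labels[partner])
```

If you need all matching samples rather than one, `same_label_mask(labels)[i].nonzero()` gives the indices for sample `i`. Note that the mask is `O(B^2)` in memory, which is negligible for typical batch sizes.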