The MNIST data loaded from the DataLoader has shape (batch_size, 1, 28, 28), and I want to extract only the samples with a specific label, e.g. label=1.
In NumPy, it can be done like this:
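The exact NumPy snippet is not shown here. One common NumPy approach, assumed for illustration, is to find the indices of the matching labels and then select whole samples along the batch axis with `np.take(..., axis=0)`. The `images` and `labels` arrays are hypothetical random stand-ins for one MNIST batch:

```python
import numpy as np

# Hypothetical stand-in for one MNIST batch: 6 images of shape (1, 28, 28)
images = np.random.rand(6, 1, 28, 28).astype(np.float32)
labels = np.array([1, 0, 1, 7, 3, 1])

# Indices of the samples whose label is 1
idx = np.where(labels == 1)[0]

# Select whole samples along the batch dimension (axis 0)
selected = np.take(images, idx, axis=0)
print(selected.shape)  # (3, 1, 28, 28)

# Equivalent boolean-mask form
selected_mask = images[labels == 1]
```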
However, in PyTorch, torch.take does not accept an axis parameter: it treats the input as if it were flattened to 1-D. Is there an elegant way to do this?
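A minimal sketch of two PyTorch options, assuming the batch has already been unpacked into an `images` tensor of shape (batch_size, 1, 28, 28) and a 1-D `labels` tensor (here replaced by hypothetical random data). Boolean-mask indexing on the first dimension is usually the shortest; `torch.index_select` with `dim=0` is the closest analogue of `np.take(..., axis=0)`:

```python
import torch

# Hypothetical stand-in for one MNIST batch from the DataLoader
images = torch.rand(6, 1, 28, 28)
labels = torch.tensor([1, 0, 1, 7, 3, 1])

# Option 1: boolean mask on the batch dimension
mask = labels == 1
selected = images[mask]
print(selected.shape)  # torch.Size([3, 1, 28, 28])

# Option 2: explicit indices + index_select along dim 0
idx = torch.nonzero(mask, as_tuple=True)[0]  # LongTensor of matching positions
selected_too = torch.index_select(images, 0, idx)
```

Both forms return a new tensor containing only the matching samples, so the remaining dimensions (1, 28, 28) are preserved.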