The MNIST batches I load from a DataLoader have shape (batch_size, 1, 28, 28), and I want to extract only the samples with a specific label, e.g. label == 1.
In NumPy, this can be done with:
new_label_idx = np.flatnonzero(label == 1)
new_data = np.take(data, new_label_idx, axis=0)
However, in PyTorch, torch.take does not accept an axis argument (it indexes into the input as if it were flattened to 1-D). Is there an elegant way to do this?
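A possible PyTorch equivalent is sketched below. It uses random tensors as a stand-in for one DataLoader batch; the names `data` and `label` mirror the NumPy snippet and are assumptions. Boolean mask indexing selects along dim 0 directly, while `torch.index_select` is the closest analogue to `np.take(..., axis=0)`.

```python
import torch

# Hypothetical stand-in for one DataLoader batch:
# data has shape (batch_size, 1, 28, 28), label has shape (batch_size,)
data = torch.randn(8, 1, 28, 28)
label = torch.tensor([1, 0, 1, 3, 1, 7, 2, 1])

# Option 1: boolean mask indexing; a (batch_size,) mask selects along dim 0
new_data = data[label == 1]

# Option 2: mirror np.flatnonzero + np.take(axis=0)
new_label_idx = torch.nonzero(label == 1).flatten()  # indices where label == 1
new_data2 = torch.index_select(data, 0, new_label_idx)

print(new_data.shape)  # torch.Size([4, 1, 28, 28])
```

Both options return a new tensor containing only the matching samples. If `data` and `label` live on a GPU, the mask and index are created on the same device automatically, since they are derived from `label`.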
Thanks!