PyTorch equivalent of numpy.take()

The MNIST data loaded from the DataLoader has shape (batch_size, 1, 28, 28), and I want to extract the samples with a specific label, e.g. label == 1.
In numpy, it can be done like:

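```python
import numpy as np

new_label_idx = np.flatnonzero(label == 1)  # positions where the label is 1
new_data = np.take(data, new_label_idx, axis=0)  # select those samples along the batch axis
```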

However, in PyTorch, `torch.take` does not accept an axis parameter. Is there an elegant way to do this?

Thanks!
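For context on the limitation mentioned above: `torch.take` indexes into the input as if it were flattened to 1-D, so the indices pick individual elements rather than rows. A small sketch on a toy tensor (not MNIST data) contrasts it with `torch.index_select`:

```python
import torch

# Toy 3x4 tensor: rows are [0..3], [4..7], [8..11]
x = torch.arange(12).reshape(3, 4)
idx = torch.tensor([0, 2])

# torch.take treats x as flattened, so this picks elements 0 and 2
flat_pick = torch.take(x, idx)

# index_select works along dim 0, so this picks rows 0 and 2
row_pick = torch.index_select(x, 0, idx)

print(flat_pick)  # tensor([0, 2])
print(row_pick)   # tensor([[ 0,  1,  2,  3], [ 8,  9, 10, 11]])
```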

You can try `torch.index_select`, which selects entries along a given dimension:

```python
import torch

x = torch.randn(3, 4)
indices = torch.tensor([0, 2])
torch.index_select(x, 0, indices)  # rows 0 and 2 of x (dim 0)
```
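Applied to the question's case, a minimal sketch might look like the following. It assumes a hypothetical random batch standing in for one DataLoader batch (`data` of shape (batch_size, 1, 28, 28) and a 1-D `label` tensor); `torch.nonzero(..., as_tuple=True)[0]` plays the role of `np.flatnonzero`:

```python
import torch

# Hypothetical stand-in for one MNIST batch (real data would come from the DataLoader)
batch_size = 8
data = torch.randn(batch_size, 1, 28, 28)
label = torch.tensor([1, 0, 1, 3, 1, 7, 2, 1])

# Equivalent of np.flatnonzero(label == 1): 1-D tensor of matching positions
new_label_idx = torch.nonzero(label == 1, as_tuple=True)[0]

# Equivalent of np.take(data, new_label_idx, axis=0)
new_data = torch.index_select(data, 0, new_label_idx)

# Boolean-mask indexing gives the same samples in one step
masked = data[label == 1]

print(new_label_idx)   # tensor([0, 2, 4, 7])
print(new_data.shape)  # torch.Size([4, 1, 28, 28])
```

If you don't need the indices themselves, the boolean mask `data[label == 1]` is the shortest option; both forms return a new tensor rather than a view.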

Thank you, yuanchao. It works!