How to select columns from a 2D tensor

PyTorch supports NumPy-style indexing. To select the first 2 columns,

import torch
a = torch.rand(4, 4)
a[:, :2]  # all rows, columns 0 and 1

should do the trick.
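The same indexing rules also cover non-contiguous columns and single columns. A minimal sketch, assuming a small integer tensor built with `torch.arange` so the selected values are easy to read (the values themselves are illustrative only):

```python
import torch

# 4x4 tensor: rows are [0..3], [4..7], [8..11], [12..15]
a = torch.arange(16).reshape(4, 4)

# Basic slicing: all rows, first two columns (returns a view of a)
first_two = a[:, :2]

# Advanced indexing with a list: arbitrary columns 0 and 3 (returns a copy)
cols_0_and_3 = a[:, [0, 3]]

# Same selection with index_select along dim=1
cols_0_and_3_alt = torch.index_select(a, 1, torch.tensor([0, 3]))

# Single column: an int index drops the dimension, a slice keeps it
col_vec = a[:, 1]   # shape (4,)
col_2d = a[:, 1:2]  # shape (4, 1)

print(first_two.shape)  # torch.Size([4, 2])
```

Slicing returns a view that shares storage with `a`, so writing to `first_two` changes `a`; list indexing and `index_select` return new tensors.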
