Basic Q: indexing a column of a tensor that could potentially be just a row

I’m a newbie to Python and PyTorch with a very basic question.

I have a tensor T and I want to access column i of this tensor:

T[:, i]

When T is two-dimensional (e.g., torch.Size([3, 4])), this works fine, but when it is one-dimensional (e.g., torch.Size([4])), I get the following error:

IndexError: too many indices for tensor of dimension 1
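A minimal reproduction, assuming the column access is written as T[:, i]. The index value and the example tensors are made up for illustration. The 2-D tensor has a second dimension to index into. The 1-D tensor does not, so the same expression raises IndexError.

```python
import torch

i = 2  # hypothetical column index chosen for illustration

# 2-D case: T[:, i] takes every row (:) and column i.
T2 = torch.arange(12).reshape(3, 4)   # torch.Size([3, 4])
col = T2[:, i]                        # torch.Size([3]), values 2, 6, 10
print(col.shape, col.tolist())

# 1-D case: there is only one dimension, so two indices are too many.
T1 = torch.arange(4)                  # torch.Size([4])
error = None
try:
    T1[:, i]
except IndexError as exc:
    error = exc
print("IndexError:", error)
```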

In the latter case I just want element i of the vector. What code should I write so that the same indexing works for both an M x N tensor and a 1-D tensor of length N?

Thanks in advance.

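You can try index_select: T.index_select(-1, torch.tensor([i])). Be careful with the shape of the output: index_select keeps the selected dimension with size 1, so a [3, 4] tensor gives [3, 1] and a [4] tensor gives [1]. Call .squeeze(-1) on the result if you want to drop that dimension.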
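A minimal sketch of the index_select suggestion. It uses torch directly instead of the tc alias and passes a 1-D index tensor, which is what the index_select documentation describes. The helper name select_column is hypothetical. The last line is a swapped-in alternative that was not part of the reply. Plain Ellipsis indexing, T[..., i], means "all leading dimensions, then index i of the last one", so it also handles both shapes.

```python
import torch

def select_column(T: torch.Tensor, i: int) -> torch.Tensor:
    """Hypothetical helper: pick index i along the last dimension of T.

    index_select keeps the selected dimension with size 1
    ([3, 4] -> [3, 1], [4] -> [1]); squeeze(-1) then removes it.
    """
    index = torch.tensor([i])              # 1-D LongTensor of indices
    return T.index_select(-1, index).squeeze(-1)

T2 = torch.arange(12).reshape(3, 4)        # M x N matrix
T1 = torch.arange(4)                       # 1-D vector of length N

print(T2.index_select(-1, torch.tensor([2])).shape)  # dimension kept, size 1
print(select_column(T2, 2))                # column 2 of the matrix, shape [3]
print(select_column(T1, 2))                # element 2 of the vector, 0-D tensor

# Alternative (not from the reply): Ellipsis indexing works for both shapes.
print(T2[..., 2], T1[..., 2])
```

Both approaches return a 1-D column for the matrix and a 0-D tensor for the vector. Ellipsis indexing is shorter. index_select is useful when the index already comes as a tensor or you want to select several columns at once.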


Works great - thank you : )

Much appreciated.