How to select arbitrary rows from a tensor?

I have a 3D tensor as follows:

import torch

a = torch.tensor([
    [[ 1,  2,  3],
     [ 4,  5,  6]],
    [[ 7,  8,  9],
     [10, 11, 12]],
])

I would like to select the first row of the first matrix and the second row of the second matrix within the 3D tensor to give the following output:

[
    [1,2,3],
    [10,11,12],
]

I'm wondering whether there is a way to do this in PyTorch without using for-loops. I looked at torch.gather and torch.index_select, but neither seems to fit my use case.
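For reference, torch.gather can express this selection once the per-matrix row indices are expanded to the shape of the output. The sketch below assumes exactly one row index per matrix, held in a hypothetical `rows` tensor; it is one possible formulation, not the only one:

```python
import torch

a = torch.tensor([
    [[1, 2, 3], [4, 5, 6]],
    [[7, 8, 9], [10, 11, 12]],
])

# Hypothetical: the row to take from each matrix (one entry per matrix).
rows = torch.tensor([0, 1])

# gather along dim=1 needs an index with as many dims as `a`.
# Reshape (B,) -> (B, 1, 1), then expand to (B, 1, C) so every column
# position reads from the chosen row of its matrix.
index = rows.view(-1, 1, 1).expand(-1, 1, a.size(2))

# out[i][0][k] = a[i][rows[i]][k]; shape (B, 1, C), so drop the middle dim.
out = torch.gather(a, 1, index).squeeze(1)
print(out)
```

The expand step is what makes gather usable here: without it, the index would only pick a single element per matrix instead of a whole row.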

Direct indexing will work: the first index list picks the matrices and the second picks the row to take from each of them.

a[[0, 1], [0, 1]]
# tensor([[ 1,  2,  3],
#         [10, 11, 12]])
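The same idea generalizes to any number of matrices by pairing `torch.arange` over the first dimension with a tensor of row choices. A minimal sketch, assuming one chosen row per matrix stored in a hypothetical `rows` tensor:

```python
import torch

a = torch.tensor([
    [[1, 2, 3], [4, 5, 6]],
    [[7, 8, 9], [10, 11, 12]],
])

# Hypothetical per-matrix row choices; any length-B integer tensor works.
rows = torch.tensor([0, 1])

# Pair matrix index i (0..B-1) with rows[i]; indexing the first two dims
# returns one full row (all columns) per matrix, shape (B, C).
batch = torch.arange(a.size(0))
out = a[batch, rows]
print(out)

# Different choices: row 1 of matrix 0 and row 0 of matrix 1.
swapped = a[batch, torch.tensor([1, 0])]
print(swapped)
```

Because both index tensors have length B, they are consumed element-wise as (matrix, row) pairs rather than forming every combination.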

@ptrblck Thank you very much. :slight_smile: