I have a 3D tensor as follows:
import torch

a = torch.tensor([
    [[1, 2, 3],
     [4, 5, 6]],
    [[7, 8, 9],
     [10, 11, 12]],
])
I would like to select the first row of the first matrix and the second row of the second matrix in the 3D tensor, to get the following output:
[[ 1,  2,  3],
 [10, 11, 12]]
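One loop-free approach worth trying is advanced (integer-array) indexing: pair a batch index built with `torch.arange` with a per-matrix row index, so element `i` of the result is `a[i, row_idx[i]]`. A minimal sketch, assuming the chosen rows are stored in a tensor named `row_idx` (a hypothetical name introduced here for illustration):

```python
import torch

a = torch.tensor([
    [[1, 2, 3],
     [4, 5, 6]],
    [[7, 8, 9],
     [10, 11, 12]],
])

# Hypothetical per-matrix row choice: row 0 of matrix 0, row 1 of matrix 1.
row_idx = torch.tensor([0, 1])

# One batch index per matrix: [0, 1].
batch_idx = torch.arange(a.size(0))

# Advanced indexing pairs the two index tensors element-wise,
# selecting a[0, 0] and a[1, 1] (each a full row of 3 values).
out = a[batch_idx, row_idx]
print(out)  # out holds [[1, 2, 3], [10, 11, 12]]
```

This generalizes to any batch size as long as `row_idx` has one entry per matrix.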
Is there a way to do this in PyTorch without using a for-loop? I looked at torch.gather and torch.index_select, but neither seems to fit my use case.
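torch.gather can also express this, but it requires an index tensor with the same number of dimensions as the input, where each entry names the row to read for that position. A minimal sketch, again assuming a hypothetical `row_idx` tensor holding one row number per matrix:

```python
import torch

a = torch.tensor([
    [[1, 2, 3],
     [4, 5, 6]],
    [[7, 8, 9],
     [10, 11, 12]],
])

# Hypothetical per-matrix row choice (must be an int64 tensor for gather).
row_idx = torch.tensor([0, 1])

# Reshape to (batch, 1, 1) and expand to (batch, 1, num_cols) so that
# every column position in a batch points at the same chosen row.
index = row_idx.view(-1, 1, 1).expand(-1, 1, a.size(2))

# Along dim=1: out[b, 0, k] = a[b, index[b, 0, k], k].
# squeeze(1) then drops the singleton row dimension.
out = torch.gather(a, 1, index).squeeze(1)
print(out)  # out holds [[1, 2, 3], [10, 11, 12]]
```

The plain indexing form is usually easier to read; gather is mainly useful when the selection is already expressed as a full-shaped index tensor.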