I have a 3D tensor as follows:
a = torch.tensor([ [ [ 1, 2, 3], [ 4, 5, 6] ], [ [ 7, 8, 9], [10, 11, 12] ] ])
I would like to select the first row of the first matrix and the second row of the second matrix within the 3D tensor to give the following output:
[[1, 2, 3], [10, 11, 12]]
Is there a way to do this in PyTorch without using for-loops? I tried torch.index_select, but it doesn't seem to work for my use case.
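torch.index_select applies the same list of indices along one dimension to every matrix, so it can't choose a different row for each matrix. Here is a minimal sketch of two loop-free alternatives. It assumes one row index per matrix, stored in a tensor (the name `row_idx` is just for illustration). The first uses advanced indexing, where integer index tensors along dims 0 and 1 are paired element-wise. The second uses torch.gather with the row indices expanded across the last dimension.

```python
import torch

a = torch.tensor([[[1, 2, 3], [4, 5, 6]],
                  [[7, 8, 9], [10, 11, 12]]])

# Illustrative choice: row 0 of matrix 0, row 1 of matrix 1.
row_idx = torch.tensor([0, 1])

# Option 1: advanced indexing. batch_idx[i] is paired with row_idx[i],
# so the result is stack([a[0, 0], a[1, 1]]) with shape (2, 3).
batch_idx = torch.arange(a.size(0))
out = a[batch_idx, row_idx]

# Option 2: torch.gather along dim 1. The index must have the same number
# of dims as `a`, so reshape to (N, 1, 1) and expand to (N, 1, C).
gather_idx = row_idx.view(-1, 1, 1).expand(-1, 1, a.size(2))
out_gather = torch.gather(a, 1, gather_idx).squeeze(1)  # shape (2, 3)

print(out)
print(out_gather)
```

Both produce [[1, 2, 3], [10, 11, 12]]. Advanced indexing is usually the more readable option. gather can be convenient when the indices already come from another tensor operation with a matching shape.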