Suppose I have the following tensors:
import torch
N = 2
k = 3
d = 2
L = torch.arange(N * k * d * d).view(N, k, d, d)
L
tensor([[[[ 0,  1],
          [ 2,  3]],

         [[ 4,  5],
          [ 6,  7]],

         [[ 8,  9],
          [10, 11]]],


        [[[12, 13],
          [14, 15]],

         [[16, 17],
          [18, 19]],

         [[20, 21],
          [22, 23]]]])
index = torch.Tensor([0,1,0,0]).view(N,-1)
index
tensor([[0., 1.],
        [0., 0.]])
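One thing to watch with this index: torch.Tensor(...) always builds a float32 tensor (hence the 0. and 1. above), and PyTorch only accepts integer or bool tensors as indices, so indexing L with it raises an IndexError. A minimal sketch of the two usual ways to get an int64 index instead:

```python
import torch

N = 2
# torch.Tensor always builds a float32 tensor, which cannot be used for indexing.
index_float = torch.Tensor([0, 1, 0, 0]).view(N, -1)

# Option 1: convert the existing tensor to int64.
index = index_float.long()

# Option 2: let torch.tensor infer int64 directly from the Python ints.
index_direct = torch.tensor([0, 1, 0, 0]).view(N, -1)

print(index.dtype, index_direct.dtype)  # torch.int64 torch.int64
```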
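I would now like to use index to pick out matrices along the second dimension of L, separately for each n: row n of index lists which of the k matrices in L[n] to take, i.e. result[n, j] = L[n, index[n, j]], giving a result of shape (N, 2, d, d). In other words, I would like to get: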
tensor([[[[ 0,  1],
          [ 2,  3]],

         [[ 4,  5],
          [ 6,  7]]],


        [[[12, 13],
          [14, 15]],

         [[12, 13],
          [14, 15]]]])
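A minimal sketch of one way to do this with advanced indexing, assuming the index has been converted to an integer (int64) tensor: pairing a column of batch positions torch.arange(N)[:, None] with index makes the two index tensors broadcast to shape (N, 2), so each entry picks L[n, index[n, j]]. The names rows, out, out_gather and ref are just illustrative.

```python
import torch

N, k, d = 2, 3, 2
L = torch.arange(N * k * d * d).view(N, k, d, d)
# Indices must be integer tensors; a float index would need .long() first.
index = torch.tensor([0, 1, 0, 0]).view(N, -1)  # shape (N, m), here m = 2

# Batch selector of shape (N, 1); it broadcasts against index (N, m),
# so out[n, j] = L[n, index[n, j]] and out has shape (N, m, d, d).
rows = torch.arange(N).unsqueeze(1)
out = L[rows, index]

# Equivalent formulation with torch.gather: the index must have as many
# dimensions as L, so expand it over the two trailing matrix dimensions.
out_gather = L.gather(1, index[:, :, None, None].expand(-1, -1, d, d))

# Slow but explicit reference loop, used only to check both results.
ref = torch.stack([torch.stack([L[n, i] for i in index[n].tolist()])
                   for n in range(N)])

print(out.shape)  # torch.Size([2, 2, 2, 2])
print(torch.equal(out, ref), torch.equal(out_gather, ref))  # True True
```

The indexing version is usually the more readable of the two; the gather version avoids building the batch selector but needs the explicit expand, since torch.gather does not broadcast its index.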
Any idea how I could achieve this?
Thank you so much!