Indexing Matrix

Suppose I have the following tensors:

import torch

N = 2
k = 3
d = 2

L = torch.arange(N * k * d * d).view(N, k, d, d)
L
tensor([[[[ 0,  1],
          [ 2,  3]],

         [[ 4,  5],
          [ 6,  7]],

         [[ 8,  9],
          [10, 11]]],


        [[[12, 13],
          [14, 15]],

         [[16, 17],
          [18, 19]],

         [[20, 21],
          [22, 23]]]])


index = torch.Tensor([0,1,0,0]).view(N,-1)
index
tensor([[0., 1.],
        [0., 0.]])
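As a side note, torch.Tensor(...) builds a tensor with the default float dtype (hence the 0., 1. above). PyTorch only accepts integer, byte or bool tensors as indices, so this index needs converting before it can be used. A minimal sketch, assuming the same N as above and the default float32 dtype:

```python
import torch

N = 2

# torch.Tensor(...) uses the default float dtype, so this index is float32
index_float = torch.Tensor([0, 1, 0, 0]).view(N, -1)
print(index_float.dtype)

# Either convert it to int64 with .long(), or build it from Python ints with torch.tensor
index = index_float.long()
index_alt = torch.tensor([0, 1, 0, 0]).view(N, -1)
print(index.dtype, index_alt.dtype)
print(torch.equal(index, index_alt))
```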

I would now like to use the index tensor to pick out matrices along the second dimension: for each sample n along the first dimension, take the matrices L[n, i] for every i in row n of index. That is, I would like to get:

tensor([[[[ 0,  1],
          [ 2,  3]],

         [[ 4,  5],
          [ 6,  7]]],


        [[[12, 13],
          [14, 15]],

         [[12, 13],
          [14, 15]]]])
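To pin down the requested semantics, out[n, j] = L[n, index[n, j]], here is a slow but unambiguous loop sketch. The helper name pick_loop is hypothetical and used only for illustration; it assumes index is already an integer tensor.

```python
import torch

N, k, d = 2, 3, 2
L = torch.arange(N * k * d * d).view(N, k, d, d)
index = torch.tensor([0, 1, 0, 0]).view(N, -1)  # integer dtype, shape (N, 2)


def pick_loop(L, index):
    """Hypothetical reference helper: out[n, j] = L[n, index[n, j]]."""
    rows = []
    for n in range(index.size(0)):
        # Stack the matrices of sample n listed in index[n]
        rows.append(torch.stack([L[n, i] for i in index[n].tolist()]))
    # Stack the per-sample results into shape (N, 2, d, d)
    return torch.stack(rows)


out = pick_loop(L, index)
print(out.shape)
```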

Any idea how I could achieve this?
Thank you so much!

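This should work. It uses advanced indexing: the batch indices torch.arange(N).unsqueeze(1) have shape (N, 1) and broadcast against index of shape (N, 2), so the result has shape (N, 2, d, d). Note that index is created with torch.tensor (int64) rather than torch.Tensor (float32), since float tensors cannot be used as indices: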

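import torch

N = 2
k = 3
d = 2

L = torch.arange(N * k * d * d).view(N, k, d, d)
index = torch.tensor([0, 1, 0, 0]).view(N, -1)  # int64, shape (N, 2)
# Row n of the result holds L[n, index[n, 0]] and L[n, index[n, 1]]
L[torch.arange(L.size(0)).unsqueeze(1), index]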
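An alternative that may also be useful is torch.gather along dim 1, with index reshaped and expanded to the full (N, m, d, d) shape. This is a swapped-in technique, not the approach from the answer above. A minimal sketch under the same toy sizes, checking that it agrees with the advanced-indexing result (m is just a local name for index.size(1)):

```python
import torch

N, k, d = 2, 3, 2
L = torch.arange(N * k * d * d).view(N, k, d, d)
index = torch.tensor([0, 1, 0, 0]).view(N, -1)  # shape (N, m) with m = 2

# Advanced indexing: (N, 1) batch indices broadcast with the (N, m) index
batch = torch.arange(L.size(0)).unsqueeze(1)
out = L[batch, index]  # shape (N, m, d, d)

# torch.gather needs an index with as many dims as L, so expand (N, m) to (N, m, d, d);
# then out_gather[n, j, a, b] = L[n, index[n, j], a, b]
m = index.size(1)
gather_index = index.view(N, m, 1, 1).expand(N, m, d, d)
out_gather = torch.gather(L, 1, gather_index)

print(out.shape)
print(torch.equal(out, out_gather))
```

Advanced indexing is usually the more readable of the two; gather is mainly handy when the index tensor already has the full shape of the output.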