How to select elements of a 3D tensor using indices along the last dimension in PyTorch

I have a tensor of size 3 x 2 x 3, like this:
[[[0, 11, 12],
  [3, 41, 51]],

 [[10, 1, 21],
  [31, 4, 51]],

 [[40, 14, 2],
  [34, 44, 5]]]

I want to apply the indices [0, 1, 2] to the last dimension, using index 0 for the first 2 x 3 block, index 1 for the second and index 2 for the third, so that the result is
[[0, 3], [1, 4], [2, 5]]

Could you help me with how to do this? Thanks.
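Put another way, each slice x[i] along the first dimension should contribute column idx[i] of its 2 x 3 block, i.e. out[i, j] = x[i, j, idx[i]]. As a slow but unambiguous reference, here is a plain-loop sketch; the name idx is illustrative and not from the question:

```python
import torch

# The 3 x 2 x 3 tensor from the question.
x = torch.tensor([[[0, 11, 12],
                   [3, 41, 51]],
                  [[10, 1, 21],
                   [31, 4, 51]],
                  [[40, 14, 2],
                   [34, 44, 5]]])

# Illustrative name: one last-dimension index per slice along dim 0.
idx = [0, 1, 2]

# Reference loop: take column idx[i] of block x[i], i.e. out[i, j] = x[i, j, idx[i]].
out = torch.stack([x[i, :, idx[i]] for i in range(x.size(0))])
print(out.tolist())  # [[0, 3], [1, 4], [2, 5]]
```

The loop makes the intended semantics explicit but runs one Python iteration per slice, so a vectorized indexing expression is preferable for large tensors.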

Use advanced indexing, pairing an index over the first dimension with the desired indices of the last dimension:

import torch

x = torch.tensor([[[0, 11, 12],
                   [3, 41, 51]],
                  [[10, 1, 21],
                   [31, 4, 51]],
                  [[40, 14, 2],
                   [34, 44, 5]]])

# torch.arange(3) and [0, 1, 2] are paired elementwise; ':' keeps all of dim 1.
out = x[torch.arange(x.size(0)), :, [0, 1, 2]]
print(out)
> tensor([[0, 3],
          [1, 4],
          [2, 5]])
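This works because torch.arange(x.size(0)) and [0, 1, 2] are both integer index sequences, so PyTorch pairs them elementwise (NumPy-style advanced indexing), while the ':' keeps the middle dimension; each pair (i, idx[i]) yields the 2-element row x[i, :, idx[i]]. An equivalent alternative, sketched here with torch.gather (a swapped-in technique, not part of the answer above; the name idx is illustrative), expands the per-slice index to shape (3, 2, 1) so the same last-dimension index is used for every row j:

```python
import torch

x = torch.tensor([[[0, 11, 12],
                   [3, 41, 51]],
                  [[10, 1, 21],
                   [31, 4, 51]],
                  [[40, 14, 2],
                   [34, 44, 5]]])

# Illustrative name: one last-dimension index per slice along dim 0 (int64).
idx = torch.tensor([0, 1, 2])

# Reshape to (3, 1, 1), then expand to (3, 2, 1): index[i, j, 0] = idx[i].
index = idx.view(-1, 1, 1).expand(-1, x.size(1), 1)

# gather along dim 2 gives out[i, j, 0] = x[i, j, idx[i]]; squeeze drops the size-1 dim.
out = torch.gather(x, 2, index).squeeze(-1)
print(out.tolist())  # [[0, 3], [1, 4], [2, 5]]
```

Both forms give the same result here; the indexing form is shorter, while gather can be handy when the indices already live in a tensor shaped like the input.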