4D Tensor indexing with `gather()`

I would like to perform the PyTorch equivalent of this NumPy command:

    import numpy as np
    a = np.random.randn(5, 4, 3, 3)
    a[range(5), :, [0, 1, 1, 2, 0], [1, 2, 0, 1, 0]]
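For reference, here is what that NumPy expression computes. Because the advanced (list) indices are separated by a slice, NumPy moves the broadcast index dimension to the front, so the result has shape `(5, 4)` and `out[i, j] == a[i, j, rows[i], cols[i]]`. The sketch below checks this against an explicit loop; `rows`, `cols` and the seeded generator are illustrative names, and `np.arange(5)` stands in for `range(5)`.

```python
import numpy as np

rng = np.random.default_rng(0)
a = rng.standard_normal((5, 4, 3, 3))
rows = [0, 1, 1, 2, 0]  # indices into dim 2
cols = [1, 2, 0, 1, 0]  # indices into dim 3

# Advanced indexing: one (row, col) pair per sample i, all 4 channels kept
fancy = a[np.arange(5), :, rows, cols]

# Explicit loop: out[i, j] = a[i, j, rows[i], cols[i]]
loop = np.empty((5, 4))
for i in range(5):
    for j in range(4):
        loop[i, j] = a[i, j, rows[i], cols[i]]

print(fancy.shape)                  # (5, 4)
print(np.array_equal(fancy, loop))  # True
```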

I am a bit confused about how to use `torch.gather()` to achieve this.
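One way to express this with `torch.gather()` is to merge the last two dimensions so that each `(row, col)` pair becomes a single linear index `row * W + col`, then gather along that merged dimension. This is a minimal sketch, assuming a contiguous 4D tensor; `rows`, `cols`, `lin` and `idx` are illustrative names, not part of any API.

```python
import torch

torch.manual_seed(0)
a = torch.randn(5, 4, 3, 3)
rows = torch.tensor([0, 1, 1, 2, 0])  # indices into dim 2
cols = torch.tensor([1, 2, 0, 1, 0])  # indices into dim 3

N, C, H, W = a.shape
# Merge the last two dims so each (row, col) pair maps to one linear index
flat = a.reshape(N, C, H * W)            # (5, 4, 9)
lin = rows * W + cols                    # (5,)
# gather needs an index with as many dims as `flat`;
# repeat each sample's linear index across all C channels
idx = lin.view(N, 1, 1).expand(N, C, 1)  # (5, 4, 1)
out = flat.gather(2, idx).squeeze(2)     # (5, 4)

# Reference: out[i, j] == a[i, j, rows[i], cols[i]]
ref = torch.stack([a[i, :, int(rows[i]), int(cols[i])] for i in range(N)])
print(torch.equal(out, ref))  # True
```

The `expand` call creates a view rather than copying the index, which `gather` accepts as-is.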

It's a bit involved to do this with PyTorch at the moment.
A proof of concept that worked before some PyTorch changes can be seen at https://gist.github.com/fmassa/f8158d1dfd25a8047c2c668a44ff57f4
(note that you first need to perform the two slicings before applying the advanced indexing).
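In PyTorch releases that support NumPy-style advanced indexing with index tensors (an assumption about the version you are running; older releases may not), the NumPy expression can be translated directly, without `gather()`. The sketch below compares the result against NumPy on the same data.

```python
import numpy as np
import torch

a = torch.randn(5, 4, 3, 3)
rows = torch.tensor([0, 1, 1, 2, 0])
cols = torch.tensor([1, 2, 0, 1, 0])

# Same expression as the NumPy one, using index tensors;
# the advanced indices are split by a slice, so their dim goes first
out = a[torch.arange(5), :, rows, cols]  # (5, 4)

expected = a.numpy()[np.arange(5), :, rows.numpy(), cols.numpy()]
print(out.shape)                              # torch.Size([5, 4])
print(np.array_equal(out.numpy(), expected))  # True
```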