I have a 4D tensor x and a 2D integer index tensor i of shape (N, 3), where each row i[n] is an index into the first three dimensions of x. I would like to extract the values of x at these indices, i.e. one slice along the last dimension of x per row of i. After some trial and error, I found that the following does what I want:

result = x[i[:, 0], i[:, 1], i[:, 2]]
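For reference, here is a small sketch of that expression on made-up shapes (A, B, C, D and N are hypothetical example sizes, and i is assumed to be int64). It also shows two slightly shorter spellings of the same advanced indexing. Both pass the columns of i as a tuple of index tensors, which PyTorch treats the same as writing the three columns out by hand.

```python
import torch

# Hypothetical example sizes: x is (A, B, C, D), i is (N, 3)
A, B, C, D, N = 4, 5, 6, 7, 10
x = torch.randn(A, B, C, D)
# Random in-range indices into the first three dimensions of x (int64)
i = torch.stack([
    torch.randint(0, A, (N,)),
    torch.randint(0, B, (N,)),
    torch.randint(0, C, (N,)),
], dim=1)

# Form from the question: one index tensor per leading dimension
result = x[i[:, 0], i[:, 1], i[:, 2]]

# Equivalent: split i into its three columns and index with the tuple
result_unbind = x[i.unbind(dim=1)]
result_tuple = x[tuple(i.t())]

# Row n of the result is the length-D vector x[i[n, 0], i[n, 1], i[n, 2]]
print(result.shape)  # torch.Size([10, 7])
```

Use a tuple here, not a list, so that the three tensors are read as separate per-dimension indices.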

I was wondering whether there is a better way to do this. I looked at torch.gather and torch.index_select, but they seem to index along only one dimension at a time (index_select, for example, takes a 1D index for a single dim). Any ideas?
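If you specifically want torch.index_select or torch.gather, one possible approach (a sketch, not the only option) is to merge the first three dimensions of x into one. Each row (i0, i1, i2) is then converted into a row-major linear offset i0*B*C + i1*C + i2, and the selection runs along that single merged dimension. The helper names select_first3 and gather_first3 are hypothetical. The sketch assumes x is 4D, i is int64 (gather requires int64 indices), and all indices are non-negative and in range. Unlike plain advanced indexing, it does not handle negative indices.

```python
import torch

def _linear_index(x: torch.Tensor, i: torch.Tensor) -> torch.Tensor:
    """Map each row (i0, i1, i2) of i to a row-major offset into A*B*C."""
    A, B, C, D = x.shape
    mult = torch.tensor([B * C, C, 1], dtype=i.dtype, device=i.device)
    return (i * mult).sum(dim=1)  # shape (N,), a 1D index

def select_first3(x: torch.Tensor, i: torch.Tensor) -> torch.Tensor:
    """x[i[n, 0], i[n, 1], i[n, 2]] for every n, via index_select."""
    A, B, C, D = x.shape
    x_flat = x.reshape(A * B * C, D)  # merge the first three dims
    return x_flat.index_select(0, _linear_index(x, i))  # shape (N, D)

def gather_first3(x: torch.Tensor, i: torch.Tensor) -> torch.Tensor:
    """Same result via gather, which needs an index with x_flat's ndim."""
    A, B, C, D = x.shape
    x_flat = x.reshape(A * B * C, D)
    idx = _linear_index(x, i).unsqueeze(1).expand(-1, D)  # shape (N, D)
    return x_flat.gather(0, idx)  # out[n, d] = x_flat[idx[n, d], d]

# Usage: compare against the advanced-indexing form from the question
x = torch.randn(4, 5, 6, 7)
i = torch.stack([torch.randint(0, s, (10,)) for s in (4, 5, 6)], dim=1)
ref = x[i[:, 0], i[:, 1], i[:, 2]]
```

For this particular task, the plain advanced-indexing form is arguably the most direct. The linear-offset trick is mainly useful when an API only accepts a single 1D index.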