How to broadcast this slicing to a 3D tensor?

Hi guys.
Consider the following simple scenario:

import torch

x = torch.randint(0, 4, (2, 3))
y = torch.randint(0, 100, (5, 3))
y[x]

Here, x holds two position vectors of length 3, and I want to extract the rows of y at those positions for each vector. Since every index picks out a row of y of length 3, y[x] has shape (2, 3, 3).
However, I cannot reproduce this once I add a third dimension, namely a batch dimension at the front:
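A quick check of the shape and of what each entry of y[x] contains might look like this (a minimal sketch; the fixed seed and the explicit loop are only for illustration). Indexing with an integer tensor replaces every index x[i, j] by the whole row y[x[i, j]], so the index shape (2, 3) is followed by the row length 3.

```python
import torch

torch.manual_seed(0)  # fixed seed only so the run is reproducible

x = torch.randint(0, 4, (2, 3))    # two position vectors, each of length 3
y = torch.randint(0, 100, (5, 3))  # 5 rows of length 3

out = y[x]  # integer-tensor indexing along dim 0 of y
print(out.shape)  # torch.Size([2, 3, 3])

# Every index x[i, j] is replaced by the full row y[x[i, j]]
for i in range(2):
    for j in range(3):
        assert torch.equal(out[i, j], y[x[i, j].item()])
```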

x = torch.randint(0,4, (4,2,3))
y = torch.randint(0,100, (4, 5, 3))
y[x]

I expect the previous operation to be applied to each element of the batch, giving a tensor of shape (4, 2, 3, 3), but that is not what I get.

Any insight?

Basically, I can obtain the expected result with this:

torch.cat([y[i][x[i]] for i in range(4)], dim=0).view(4, 2, 3, 3)

But I would like to avoid the for loop.
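One loop-free option, sketched here as an assumption rather than something proposed in the thread, is to index the first two dimensions of y at once: a batch index tensor of shape (4, 1, 1) broadcasts against x's shape (4, 2, 3), so each x[b, i, j] is looked up only inside y[b]. The block checks this against the loop version and also shows an equivalent torch.gather formulation along dim 1.

```python
import torch

torch.manual_seed(0)  # fixed seed only so the run is reproducible

B = 4
x = torch.randint(0, 4, (B, 2, 3))    # per-batch position vectors into dim 1 of y
y = torch.randint(0, 100, (B, 5, 3))  # per-batch tables of 5 rows of length 3

# Loop version from the post: one lookup per batch element.
expected = torch.cat([y[i][x[i]] for i in range(B)], dim=0).view(B, 2, 3, 3)

# Advanced indexing: batch_idx (B, 1, 1) broadcasts against x (B, 2, 3),
# so result[b, i, j] == y[b, x[b, i, j]], a row of length 3.
batch_idx = torch.arange(B).view(B, 1, 1)
result = y[batch_idx, x]
print(result.shape)  # torch.Size([4, 2, 3, 3])
assert torch.equal(result, expected)

# Alternative with torch.gather along dim 1: flatten the (2, 3) index grid,
# repeat each index across the last dim of y, then restore the shape.
idx = x.reshape(B, -1, 1).expand(-1, -1, y.size(2))  # shape (B, 6, 3)
gathered = torch.gather(y, 1, idx).reshape(B, 2, 3, 3)
assert torch.equal(gathered, expected)
```

Both versions avoid a Python-level loop; the indexing form is shorter, while the gather form states the "look up along dim 1" intent explicitly at the cost of an expand and a reshape.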

You are indexing only the first dimension of y with x. With x of shape (4, 2, 3), each element of x selects one slice along the first dimension of y, and each such slice has shape (5, 3). That's why the result has shape (4, 2, 3, 5, 3).
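This is easy to confirm (a minimal sketch with a fixed seed; the chosen b, i, j are arbitrary): each scalar in x selects a whole (5, 3) slice of y, and that slice's shape is appended to x's shape.

```python
import torch

torch.manual_seed(0)  # fixed seed only so the run is reproducible

x = torch.randint(0, 4, (4, 2, 3))
y = torch.randint(0, 100, (4, 5, 3))

out = y[x]  # x indexes only dim 0 of y; dims 1 and 2 are carried along
print(out.shape)  # torch.Size([4, 2, 3, 5, 3])

# Each scalar x[b, i, j] selects the entire (5, 3) slice y[x[b, i, j]]
b, i, j = 3, 1, 2
assert torch.equal(out[b, i, j], y[x[b, i, j].item()])
```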