Hi guys.
Consider the following trivial scenario:
import torch
x = torch.randint(0,4, (2,3))
y = torch.randint(0,100, (5, 3))
y[x]
Here, x simply holds two position vectors (each with 3 indices), and I want to extract the rows of y at every position in each vector. As a result, y[x] has shape (2x3x3).
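To make the unbatched behaviour concrete, here is a small sketch (the fixed seed, the `index_select` comparison and the loop are additions, not part of the original post). It shows that `y[x]` just picks rows of `y` along dim 0, so the result shape is `x.shape + y.shape[1:]`.

```python
import torch

torch.manual_seed(0)  # fixed seed so the example is reproducible

x = torch.randint(0, 4, (2, 3))    # two index vectors of length 3 (values in [0, 4))
y = torch.randint(0, 100, (5, 3))  # 5 rows, 3 columns

out = y[x]  # integer-tensor indexing selects rows of y along dim 0
print(out.shape)  # torch.Size([2, 3, 3])

# Equivalent formulation: gather the rows with index_select, then restore x's shape.
same = y.index_select(0, x.reshape(-1)).reshape(*x.shape, y.shape[1])
assert torch.equal(out, same)

# Element-wise meaning: out[i, j] is the row y[x[i, j]].
for i in range(x.shape[0]):
    for j in range(x.shape[1]):
        assert torch.equal(out[i, j], y[x[i, j]])
```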
Now, I cannot reproduce the same result when I add a third dimension, namely a batch dimension at the front. In this case:
x = torch.randint(0,4, (4,2,3))
y = torch.randint(0,100, (4, 5, 3))
y[x]
I expect the previous operation to be broadcast over each element of the batch, giving a tensor of shape (4x2x3x3), but that is not what I get: the output has shape (4x2x3x5x3).
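A sketch of one common way to get the per-batch behaviour, assuming the goal is `out[b, i, j] = y[b, x[b, i, j]]`. A single index tensor always indexes dim 0, which here is the batch dimension, so `y[x]` returns `x.shape + y.shape[1:]`. Pairing every index in `x` with its batch position through a broadcast `torch.arange` makes advanced indexing select along dim 1 inside each batch element instead. The names `B`, `N` and `batch_idx` are illustrative, and the `torch.gather` version and the loop are only cross-checks.

```python
import torch

torch.manual_seed(0)

B, N = 4, 5  # batch size and rows per batch element
x = torch.randint(0, 4, (B, 2, 3))    # per-batch index tensors
y = torch.randint(0, 100, (B, N, 3))  # per-batch tables of rows

# Plain y[x] indexes dim 0 of y (the batch dimension) and keeps the
# remaining dims (5, 3), so the shape is x.shape + y.shape[1:].
naive = y[x]
print(naive.shape)  # torch.Size([4, 2, 3, 5, 3])

# Batched gather via advanced indexing: pair each index in x with its batch id.
# batch_idx has shape (B, 1, 1) and broadcasts against x's shape (B, 2, 3).
batch_idx = torch.arange(B).view(B, 1, 1)
out = y[batch_idx, x]
print(out.shape)  # torch.Size([4, 2, 3, 3])

# Same result with torch.gather: flatten x's trailing dims, expand over y's last dim.
idx = x.reshape(B, -1, 1).expand(-1, -1, y.shape[2])  # shape (B, 6, 3)
out_gather = torch.gather(y, 1, idx).reshape(*x.shape, y.shape[2])
assert torch.equal(out, out_gather)

# Reference loop: out[b] should equal the unbatched y[b][x[b]].
for b in range(B):
    assert torch.equal(out[b], y[b][x[b]])
```

Both approaches produce the same tensor. The advanced-indexing form reads closest to the unbatched `y[x]`, while `torch.gather` requires the index to have the same number of dimensions as `y`, hence the reshape and expand.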
Any insight?