Gather with a third dimension

I have a tensor X of size (x, y, z) and a vector V of indices of size (x). I want each element of V to indicate which subtensor along dim=1 should be gathered from X. In other words, gathering V from X should return a size (x, z) tensor K such that K[i, :] == X[i, V[i], :]

I know how to gather from a tensor X’ of size (x, y):
X’.gather(dim=1, index=V.unsqueeze(1)) returns a tensor K of size (x, 1) such that K[i, 0] == X’[i, V[i]].
But how do you extend this to the case with an additional dimension?

X.gather(dim=1, index=V.unsqueeze(1).unsqueeze(2)) doesn’t seem to work.

Say you have

X = torch.randn(64,20,100)
V = torch.randint(0, 20, (64,), dtype=torch.int64)

then you could spell the above out as

i = torch.arange(0, 64, dtype=torch.int64)
K = X[i, V]  # K has size (64, 100), with K[i, :] == X[i, V[i], :]
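
If you would rather stay with gather, one possible sketch is to expand the index along the last dimension so it has as many dimensions as X. The names idx, K_index and K_gather below are only for illustration, and the shapes are the ones assumed above:

```python
import torch

X = torch.randn(64, 20, 100)
V = torch.randint(0, 20, (64,), dtype=torch.int64)

# Indexing version: pair each row number with its entry of V.
i = torch.arange(X.size(0), dtype=torch.int64)
K_index = X[i, V]                                  # size (64, 100)

# gather version: the index needs the same number of dimensions as X,
# so reshape V to (64, 1, 1) and expand it along the last dimension.
idx = V.view(-1, 1, 1).expand(-1, 1, X.size(2))    # size (64, 1, 100)
K_gather = X.gather(1, idx).squeeze(1)             # size (64, 100)

# Both approaches select the same values.
assert torch.equal(K_index, K_gather)
```

The indexing version is shorter; the gather version may be handy when the index already has extra dimensions.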

Best regards

Thomas

Works great, thanks!