I have a tensor X of size (x, y, z) and a vector V of indices of size (x). I want each element of V to indicate which slice along dim=1 should be gathered from X. In other words, gathering V from X should return a tensor K of size (x, z) such that K[i, :] == X[i, V[i], :].
I know how to gather from a tensor X’ of size (x, y).
X’.gather(dim=1, index=V.unsqueeze(1)).squeeze(1) returns a tensor K of size (x) such that K[i] == X’[i, V[i]]. Without the squeeze, the result has size (x, 1).
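A small sketch of the 2D case, assuming arbitrary example sizes and values (X2 stands in for X’, since X’ is not a valid Python name). It also checks the result against explicit per-row indexing:

```python
import torch

# Example sizes and data, chosen only for illustration
x, y = 4, 5
X2 = torch.arange(x * y).reshape(x, y)  # stands in for X' of size (x, y)
V = torch.tensor([0, 3, 1, 4])           # one dim=1 index per row

# gather needs an index with as many dims as the input,
# so V is unsqueezed to (x, 1); the result is (x, 1) before squeezing
K = X2.gather(dim=1, index=V.unsqueeze(1)).squeeze(1)  # size (x,)

# Cross-check: K[i] == X2[i, V[i]] for every row
assert all(K[i] == X2[i, V[i]] for i in range(x))
print(K)
```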
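But how do you extend this to the case with an additional dimension?
X.gather(dim=1, index=V.unsqueeze(1).unsqueeze(2)) doesn’t give what I want: it returns a tensor of size (x, 1, 1) containing only X[i, V[i], 0], not the full slice of size (x, z).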
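One possible approach, sketched with arbitrary example sizes: for dim=1, gather computes out[i][j][k] = X[i][index[i][j][k]][k], so the output has the same shape as the index. An index of size (x, 1, 1) therefore only reaches k = 0. Repeating V[i] across the last dimension, giving an index of size (x, 1, z), selects the whole slice. Advanced indexing with X[torch.arange(x), V] is shown as an alternative that should produce the same result:

```python
import torch

# Example sizes and data, chosen only for illustration
x, y, z = 3, 4, 2
X = torch.arange(x * y * z).reshape(x, y, z)
V = torch.tensor([2, 0, 3])  # V[i] selects X[i, V[i], :]

# Repeat V[i] across the last dim: (x,) -> (x, 1, 1) -> (x, 1, z)
index = V.view(-1, 1, 1).expand(-1, 1, z)
K = X.gather(dim=1, index=index).squeeze(1)  # (x, 1, z) -> (x, z)

# Alternative: advanced indexing pairs row i with position V[i] along dim=1
K_alt = X[torch.arange(x), V]

assert K.shape == (x, z)
assert torch.equal(K, K_alt)
print(K)
```

The expand call does not copy data, so the index stays cheap even when z is large; the advanced-indexing form is shorter but does the same selection.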