Slicing a 3D tensor

Suppose I have a tensor X of shape (B, N, 1) and a tensor of indices Y of shape (B, 1) (optionally squeezed to shape (B,)). I want to use Y to index X along dimension 1, the N dimension, separately for each batch element.

Suppose the first element of Y is j. How do I index X to get a tensor Z of shape (B, 1) whose first element is the jth element of X[0] along that dimension? In general, for each batch index b, I want Z[b] = X[b, Y[b, 0]].

An example would be if X is [[[1], [2]], [[3], [4]]] and Y is [[0], [1]] then I would want Z to be [[1], [4]].
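To pin down the intended semantics, here is a minimal reference sketch that uses plain nested lists (an illustrative choice, not something from the question) with the example values above. The helper name `reference_select` is hypothetical. It simply picks X[b][Y[b][0]] for every batch index b.

```python
# Reference semantics: Z[b] = X[b][Y[b][0]] for every batch index b.
# Plain nested lists are used so the indexing is fully explicit.
X = [[[1], [2]], [[3], [4]]]  # shape (B, N, 1) with B = 2, N = 2
Y = [[0], [1]]                # shape (B, 1): one index into dimension 1 per batch element

def reference_select(X, Y):
    """Hypothetical helper: pick X[b][Y[b][0]] for each batch element b."""
    return [X[b][Y[b][0]] for b in range(len(X))]

Z = reference_select(X, Y)
print(Z)  # [[1], [4]]
```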


You can use torch.gather(...) along dim=1. Link: torch.gather — PyTorch 1.9.1 documentation
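A minimal sketch of the torch.gather approach, assuming the example tensors from the question. torch.gather requires the index tensor to have the same number of dimensions as X and an int64 dtype, so Y is given a trailing axis before gathering and the result is squeezed back to (B, 1). For comparison, the block also shows advanced indexing, which is a different technique from torch.gather but returns the same values here.

```python
import torch

X = torch.tensor([[[1], [2]], [[3], [4]]])  # shape (B, N, 1) = (2, 2, 1)
Y = torch.tensor([[0], [1]])                # shape (B, 1), dtype int64

# The index must have as many dims as X: (B, 1) -> (B, 1, 1).
# With dim=1, out[b][i][k] = X[b][index[b][i][k]][k].
Z = torch.gather(X, 1, Y.unsqueeze(-1)).squeeze(-1)  # (B, 1, 1) -> (B, 1)
print(Z.tolist())  # [[1], [4]]

# Alternative (advanced indexing, not torch.gather):
# pair each batch index b with its entry Y[b, 0].
B = X.shape[0]
Z_alt = X[torch.arange(B), Y.squeeze(1)]  # shape (B, 1)
print(Z_alt.tolist())  # [[1], [4]]
```

If Y has already been squeezed to shape (B,), use Y.view(-1, 1, 1) as the gather index instead of Y.unsqueeze(-1).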
