Indexing a 3d tensor with a 1d tensor

I have a 3d input tensor of size N x H x W. I’d like to index it along dimension 1 with a 1d tensor of length N, whose entries are of type long in {0, 1, …, H-1}. The returned tensor would have size N x 1 x W, holding one selected row per batch element. Can this be done in PyTorch?


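Ah, I figured it out. I wasn’t using gather properly before: the index tensor has to have the same number of dimensions as the input. To get it working, I now call gather along dimension 1 with a 3d index tensor of size N x 1 x W. To build it, I reshape the 1d index to N x 1 x 1 and use expand to repeat it W times along the last dimension.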
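For anyone landing here later, this is a minimal sketch of what that looks like. The sizes, the example values, and the names x, idx, index and out are made up for illustration and are not from my actual code. The last lines check the result against plain advanced indexing, which I believe gives the same values.

```python
import torch

# Illustrative sizes and data (assumptions, not from the original problem)
N, H, W = 4, 5, 3
x = torch.arange(N * H * W, dtype=torch.float32).reshape(N, H, W)

# 1d index of length N, dtype long, entries in {0, ..., H-1}
idx = torch.tensor([0, 2, 4, 1])

# gather needs an index with the same number of dims as x:
# reshape to N x 1 x 1, then expand to N x 1 x W (no data is copied)
index = idx.view(N, 1, 1).expand(N, 1, W)

# out[n, 0, w] = x[n, idx[n], w]
out = x.gather(1, index)  # size N x 1 x W

# Cross-check with advanced indexing: pick row idx[n] for each n
alt = x[torch.arange(N), idx].unsqueeze(1)
assert torch.equal(out, alt)
```

The advanced-indexing form is shorter if you do not specifically need gather; both return a tensor of size N x 1 x W.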