I have a 3D input tensor of size N x H x W. I’d like to index it along dimension 1 with a 1D tensor of length N whose entries are of type long in {0, 1, …, H-1}. The returned tensor would have size N x 1 x W. Can this be done in PyTorch?
Thanks!
Ah, I figured it out. I wasn’t using gather properly before. To get it working, I started calling gather along dimension 1 with a 3D index tensor of size N x 1 x W. To build it, I reshape the index to N x 1 x 1 and use expand to repeat it W times along the last dimension.
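A minimal sketch of that approach, assuming `x` has shape (N, H, W) and `idx` is a 1D long tensor of length N with values in [0, H-1]. The helper name `select_rows` is hypothetical. The last lines cross-check the result against plain advanced indexing, which should give the same values.

```python
import torch


def select_rows(x: torch.Tensor, idx: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: for every n, pick row idx[n] out of x[n].

    x:   (N, H, W) tensor
    idx: (N,) long tensor with values in [0, H-1]
    returns: (N, 1, W) tensor
    """
    N, H, W = x.shape
    # gather needs an index with the same number of dims as x:
    # reshape (N,) -> (N, 1, 1), then expand (a view, no copy) -> (N, 1, W)
    index = idx.view(N, 1, 1).expand(N, 1, W)
    # out[n][0][w] = x[n][index[n][0][w]][w] = x[n][idx[n]][w]
    return torch.gather(x, 1, index)


N, H, W = 4, 5, 3
x = torch.arange(N * H * W, dtype=torch.float32).view(N, H, W)
idx = torch.tensor([0, 2, 4, 1])

out = select_rows(x, idx)
print(out.shape)  # torch.Size([4, 1, 3])

# Cross-check: advanced indexing yields (N, W); unsqueeze restores the middle dim
ref = x[torch.arange(N), idx].unsqueeze(1)
assert torch.equal(out, ref)
```

Because expand only creates a view, the N x 1 x W index does not allocate W copies of the data. The advanced-indexing form is an equivalent alternative if you find it easier to read.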