Index a tensor using two tensors

Hi, I have an N x IC x IH x IW input, and I want to index the IH and IW dimensions using two N x M tensors to get an N x IC x M output. How can I do this?

Advanced indexing (see the NumPy reference on advanced indexing; PyTorch aims to be compatible with it here) should do the trick:

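```python
import torch

values = torch.arange(2 * 3 * 5 * 5).view(2, 3, 5, 5)  # N=2, IC=3, IH=IW=5
idx_x = torch.randint(0, 5, (4,))  # M=4 row indices into IH
idx_y = torch.randint(0, 5, (4,))  # M=4 column indices into IW
values[:, :, idx_x, idx_y]  # shape (2, 3, 4)
```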
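A quick check of what that indexing does, sketched with the same shapes as the snippet above. The two index tensors are paired element by element, so output column m holds the element at (idx_x[m], idx_y[m]) for every batch entry and channel. Note that these indices are shared across the whole batch.

```python
import torch

values = torch.arange(2 * 3 * 5 * 5).view(2, 3, 5, 5)
idx_x = torch.randint(0, 5, (4,))
idx_y = torch.randint(0, 5, (4,))

out = values[:, :, idx_x, idx_y]
print(out.shape)  # prints torch.Size([2, 3, 4])

# Column m of the output is the (idx_x[m], idx_y[m]) pixel of every image/channel
for m in range(4):
    assert torch.equal(out[:, :, m], values[:, :, idx_x[m], idx_y[m]])
```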

Best regards


Hi, thanks for replying, but in my case idx_x and idx_y have shape 2 x 4 instead of just 4, meaning each example in the batch has its own indices. Is this still possible with advanced indexing, or do I have to loop over the examples in the batch?
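One possible loop-free approach, sketched here as a suggestion rather than a confirmed answer from the thread: add a batch index of shape (N, 1) so that it broadcasts against the (N, M) index tensors. Moving IC to the last dimension keeps the three advanced indices adjacent, which gives an (N, M, IC) result that is then permuted back to (N, IC, M). An alternative using `torch.gather` on the flattened IH*IW dimension is shown for comparison. The names `batch`, `flat` and `lin` are hypothetical helpers introduced for illustration.

```python
import torch

N, IC, IH, IW, M = 2, 3, 5, 5, 4
values = torch.arange(N * IC * IH * IW).view(N, IC, IH, IW)
idx_x = torch.randint(0, IH, (N, M))  # per-example row indices
idx_y = torch.randint(0, IW, (N, M))  # per-example column indices

# Option 1: advanced indexing with a broadcast batch index.
batch = torch.arange(N).unsqueeze(1)  # shape (N, 1), broadcasts to (N, M)
# (N, IH, IW, IC)[batch, idx_x, idx_y] -> (N, M, IC), then -> (N, IC, M)
out = values.permute(0, 2, 3, 1)[batch, idx_x, idx_y].permute(0, 2, 1)

# Option 2: gather along the flattened spatial dimension.
flat = values.flatten(2)  # shape (N, IC, IH * IW)
lin = idx_x * IW + idx_y  # shape (N, M), linear indices into IH * IW
out2 = flat.gather(2, lin.unsqueeze(1).expand(-1, IC, -1))  # shape (N, IC, M)

# Both satisfy out[n, c, m] == values[n, c, idx_x[n, m], idx_y[n, m]]
```

The permute in option 1 avoids relying on how mixed slices and advanced indices reorder dimensions; the gather variant avoids the permutes at the cost of computing linear indices by hand.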