Get batched tensor elements by index

I want to use the indices returned by max_pool2d_with_indices(x) to extract elements from a tensor m with the same dimensions as x. However, the returned values are wrong. Here is a minimal working example:

import torch

def main():

    x = [[[[0, 0, 0, 0],
           [0, 1, 1, 0],
           [0, 1, 1, 0],
           [0, 0, 0, 0]],

         [[1, 0, 0, 1],
          [0, 0, 0, 0],
          [0, 0, 0, 0],
          [1, 0, 0, 1]]]]

    x = torch.tensor(x).float()

    m = [[[[10, 11, 12, 13],
           [14, 15, 16, 17],
           [18, 19, 20, 21],
           [22, 23, 24, 25]],

          [[26, 27, 28, 29],
           [30, 31, 32, 33],
           [34, 35, 36, 37],
           [38, 39, 40, 41]]]]

    m = torch.tensor(m).float()

    y, idx = torch.nn.functional.max_pool2d_with_indices(x, kernel_size=(2, 2))

    m_select = torch.take(m, idx)
    print(m_select)


if __name__ == "__main__":
    main()

I want m_select to be

tensor([[[[15., 16.],
          [19., 20.]],

         [[26., 29.],
          [38., 41.]]]])

but currently I get

tensor([[[[15., 16.],
          [19., 20.]],

         [[10., 13.],
          [22., 25.]]]])

In other words, I want the elements of m at the positions where the activations in x are highest. Currently, torch.take() only returns elements from the first feature map. How can I get the desired output?

You could flatten the tensors and use gather to get the desired output:

out = torch.gather(torch.flatten(m, 2), 2, torch.flatten(idx, 2)).view(idx.size())
print(out)
>  tensor([[[[15., 16.],
            [19., 20.]],

           [[26., 29.],
            [38., 41.]]]])
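To see why torch.take() only picked values from the first feature map, it may help to look at idx itself. The sketch below rebuilds x from the question and inspects the indices; the alias F is just a local import name. It suggests that the indices are relative to each flattened feature map (0 to H*W - 1), which is why gathering along the flattened spatial dimension lines up with them:

```python
import torch
import torch.nn.functional as F

# Same x as in the question: shape (N=1, C=2, H=4, W=4).
x = torch.tensor([[[[0, 0, 0, 0],
                    [0, 1, 1, 0],
                    [0, 1, 1, 0],
                    [0, 0, 0, 0]],

                   [[1, 0, 0, 1],
                    [0, 0, 0, 0],
                    [0, 0, 0, 0],
                    [1, 0, 0, 1]]]]).float()

_, idx = F.max_pool2d_with_indices(x, kernel_size=(2, 2))

# Indices for the first feature map:  5, 6, 9, 10
# Indices for the second feature map: 0, 3, 12, 15
# Both sets lie in [0, H*W), i.e. each feature map is indexed on its own.
print(idx)
print(idx.max().item() < x.size(2) * x.size(3))  # True
```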

Thank you!

Would it also be possible to use torch.take() by modifying idx so that the indices do not restart from 0 for every feature map? Is there a simple way to do that?

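torch.take() treats the input tensor as if it were flattened to 1D, whereas the pooling indices are relative to each feature map. You could add a per-feature-map offset (the position of the feature map in the flattened batch and channel dimensions, times height * width) to the indices so that they address the values in the flattened tensor.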
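A minimal sketch of the offset idea, assuming m is contiguous with shape (N, C, H, W) as in the question. The names offsets, n, c, h, w and the alias F are illustrative and not from the original posts; m is built with torch.arange here only as a shorthand for the values 10 to 41 in the question:

```python
import torch
import torch.nn.functional as F

x = torch.tensor([[[[0, 0, 0, 0],
                    [0, 1, 1, 0],
                    [0, 1, 1, 0],
                    [0, 0, 0, 0]],

                   [[1, 0, 0, 1],
                    [0, 0, 0, 0],
                    [0, 0, 0, 0],
                    [1, 0, 0, 1]]]]).float()

# Same values as m in the question (10 .. 41), shape (1, 2, 4, 4).
m = torch.arange(10, 42).float().view(1, 2, 4, 4)

_, idx = F.max_pool2d_with_indices(x, kernel_size=(2, 2))

n, c, h, w = m.shape
# Feature map k (counting over batch and channel) starts at k * h * w
# in the flattened view of m, so shift its indices by that amount.
offsets = (torch.arange(n * c, device=idx.device) * (h * w)).view(n, c, 1, 1)

m_select = torch.take(m, idx + offsets)
print(m_select)
# Values: [[15, 16], [19, 20]] for the first feature map,
#         [[26, 29], [38, 41]] for the second.
```

Compared with the gather approach, this needs the explicit offset bookkeeping, so gather is probably the less error-prone choice when the tensors have more than one feature map.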
