MaxPool2d indexing order

I’m trying to understand how the indices returned by MaxPool2d work.
I made a simple example where I max-pool a 4x4 input with a 2x2 kernel and stride 2.

import torch
import torch.nn as nn
import torch.nn.functional as fn

# The two input planes for maxpooling
x1 = torch.Tensor(
    [[1, 2, 4, 5],
     [4, 1, 2, 1],
     [0, 5, 0, 0],
     [1, 1, 1, 1]]
)

x2 = torch.Tensor(
    [[2, 2, 2, 2],
     [4, 0, 1, 4],
     [5, 2, 2, 1],
     [0, 1, 4, 1]]
)

# Batch size: 1
# Channels: 2
input = torch.stack((x1, x2), 0).unsqueeze(0)

pool = nn.MaxPool2d(2, 2, return_indices=True)

o, i = pool(input)

print(o)
print(i)

What I get is:

Variable containing:
(0 ,0 ,.,.) = 
  4  5
  5  1

(0 ,1 ,.,.) = 
  4  4
  5  4
[torch.FloatTensor of size 1x2x2x2]

Variable containing:
(0 ,0 ,.,.) = 
   4   3
   9  14

(0 ,1 ,.,.) = 
   4   7
   8  14
[torch.LongTensor of size 1x2x2x2]

The first output is the result of the max-pooling, and it’s correct.
The second output holds the indices of the max values.
I don’t understand how these indices map to the dimensions of the input.
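A quick way to test a hypothesis about the layout: assume each index is a flat offset `row * W + col` inside its own channel's H x W plane, counted separately for every (batch, channel) pair. If that assumption holds, decoding the indices into (row, col) pairs and reading the input at those positions should reproduce the pooled values. This sketch rebuilds the two planes from the question with `torch.tensor` and names the input `inp` (instead of `input`, which shadows a Python builtin).

```python
import torch
import torch.nn as nn

# Same two 4x4 planes as in the question, stacked into shape (N=1, C=2, H=4, W=4).
x1 = torch.tensor([[1., 2., 4., 5.],
                   [4., 1., 2., 1.],
                   [0., 5., 0., 0.],
                   [1., 1., 1., 1.]])
x2 = torch.tensor([[2., 2., 2., 2.],
                   [4., 0., 1., 4.],
                   [5., 2., 2., 1.],
                   [0., 1., 4., 1.]])
inp = torch.stack((x1, x2)).unsqueeze(0)

pool = nn.MaxPool2d(2, 2, return_indices=True)
out, idx = pool(inp)

W = inp.size(3)
# Assumption under test: index = row * W + col within each channel's own plane.
rows = torch.div(idx, W, rounding_mode="floor")
cols = idx % W

# Read the input at the decoded positions and compare with the pooled output.
for c in range(inp.size(1)):
    for r in range(out.size(2)):
        for k in range(out.size(3)):
            row, col = rows[0, c, r, k].item(), cols[0, c, r, k].item()
            assert inp[0, c, row, col] == out[0, c, r, k]

print(idx[0, 1])             # channel 1 indices
print(rows[0, 1], cols[0, 1])  # decoded (row, col) positions in channel 1
```

Channel 1's indices (4, 7, 8, 14) decode to positions inside `x2` itself, which is consistent with per-channel offsets rather than offsets into the whole tensor.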

If I select the elements in the input tensor with these indices, I expect to get the result from the max-pooling. However:

tmp = torch.index_select(input.view(-1), 0, i.data.view(-1))
print(tmp.view(1, 2, 2, 2))

This gives me the output

(0 ,0 ,.,.) = 
  4  5
  5  1

(0 ,1 ,.,.) = 
  4  1
  0  1
[torch.FloatTensor of size 1x2x2x2]

which is not the same as the output of the max-pooling.
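The mismatch comes from `index_select` on `input.view(-1)` treating every index as an offset into the whole flattened tensor, so channel 1's indices land in channel 0's plane. A minimal sketch, assuming the per-plane layout, adds the start offset of each (batch, channel) plane before calling `index_select`. The random 2x3x4x4 input is hypothetical, chosen to show that the same fix covers several batches and channels.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
# Hypothetical input: a batch of N=2 images with C=3 channels of size 4x4.
inp = torch.randn(2, 3, 4, 4)
out, idx = nn.MaxPool2d(2, 2, return_indices=True)(inp)

N, C, H, W = inp.shape
# Flat offset where each (n, c) plane starts in the fully flattened tensor.
plane_start = torch.arange(N * C).view(N, C, 1, 1) * (H * W)
global_idx = idx + plane_start  # per-plane offsets -> whole-tensor offsets

picked = torch.index_select(inp.reshape(-1), 0, global_idx.reshape(-1)).view(out.shape)
print(torch.equal(picked, out))
```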

My question is, how can I use the indices returned by MaxPool2d to “manually” select the max elements in the input tensor?


Did you ever figure this out? I’m trying to understand this as well.

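The indices are flat offsets into each channel's H x W plane (row * W + col), computed separately for every (batch, channel) pair, not offsets into the whole flattened tensor. You could flatten just the spatial dimensions of the input and use gather along that dimension: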

x = torch.flatten(input, 2)
o2 = torch.gather(x, 2, torch.flatten(i, 2)).view(o.size())
print((o==o2).all())
> tensor(1, dtype=torch.uint8)
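The gather approach should carry over to batch sizes larger than 1, since `gather` looks up each (batch, channel) row independently. The sketch below uses a hypothetical random 4x3x6x6 input and also tries `torch.take_along_dim` (available in PyTorch 1.9 and later), which performs the same lookup along one dimension. On recent PyTorch versions the comparison prints `tensor(True)` rather than a uint8 tensor.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
# Hypothetical batch: N=4 images, C=3 channels, 6x6 spatial size.
inp = torch.randn(4, 3, 6, 6)
out, idx = nn.MaxPool2d(2, 2, return_indices=True)(inp)

# Flatten only H and W: input becomes (N, C, H*W), indices (N, C, H_out*W_out).
flat_in = torch.flatten(inp, 2)
flat_idx = torch.flatten(idx, 2)

via_gather = torch.gather(flat_in, 2, flat_idx).view_as(out)
# Same lookup with take_along_dim (PyTorch >= 1.9).
via_take = torch.take_along_dim(flat_in, flat_idx, dim=2).view_as(out)

print(torch.equal(via_gather, out), torch.equal(via_take, out))
```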