# MaxPool2d indexing order

I’m trying to understand how the indices of MaxPool2d work.
I made a simple example where I max-pool a 4x4 region with pools of size 2 and stride 2.

``````
import torch
import torch.nn as nn
import torch.nn.functional as fn

# The two input planes for maxpooling
x1 = torch.Tensor(
[[1, 2, 4, 5],
[4, 1, 2, 1],
[0, 5, 0, 0],
[1, 1, 1, 1]]
)

x2 = torch.Tensor(
[[2, 2, 2, 2],
[4, 0, 1, 4],
[5, 2, 2, 1],
[0, 1, 4, 1],
]
)

# Batch size: 1
# Channels: 2
input = torch.stack((x1, x2), 0).unsqueeze(0)

pool = nn.MaxPool2d(2, 2, return_indices=True)

o, i = pool(input)

print(o)
print(i)
``````

What I get is

``````
Variable containing:
(0 ,0 ,.,.) =
4  5
5  1

(0 ,1 ,.,.) =
4  4
5  4
[torch.FloatTensor of size 1x2x2x2]

Variable containing:
(0 ,0 ,.,.) =
4   3
9  14

(0 ,1 ,.,.) =
4   7
8  14
[torch.LongTensor of size 1x2x2x2]
``````

The first output is the result of the max-pool and it’s correct.
The second output contains the indices of the max values.
I don’t understand how these are aligned with the dimensions.

If I select the elements in the input tensor with these indices, I expect to get the result from the max-pooling. However:

``````
tmp = torch.index_select(input.view(-1), 0, i.data.view(-1))
print(tmp.view(1, 2, 2, 2))
``````

This gives me the output

``````
(0 ,0 ,.,.) =
4  5
5  1

(0 ,1 ,.,.) =
4  1
0  1
[torch.FloatTensor of size 1x2x2x2]
``````

which is not the same as the output of the max-pooling.

My question is, how can I use the indices returned by MaxPool2d to “manually” select the max elements in the input tensor?


Did you ever figure this out? I’m trying to understand this as well.

You could flatten the spatial dimensions of the `input` tensor and use `gather` along the flattened dimension:

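``````
# Collapse H and W into one dimension: (N, C, H, W) -> (N, C, H*W)
x = torch.flatten(input, 2)
# Pick the max values per channel, then restore the pooled shape
o2 = torch.gather(x, 2, torch.flatten(i, 2)).view(o.size())
print((o == o2).all())
# Output: tensor(1, dtype=torch.uint8)
# (newer PyTorch versions print tensor(True))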
``````
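
To spell out why `gather` works and the plain `index_select` attempt does not: judging by the printed indices above, each index is a flat offset into its own `H*W` plane (`row * W + col`), not into the whole flattened tensor. That is why the second channel picked up values from the first one. Below is a minimal sketch under that assumption, reusing the example from the question. The names `inp`, `plane_offsets`, `global_idx`, `n_idx` and `c_idx` are only illustrative, and `torch.div(..., rounding_mode="floor")` assumes a reasonably recent PyTorch version.

```python
import torch
import torch.nn as nn

x1 = torch.tensor([[1., 2., 4., 5.],
                   [4., 1., 2., 1.],
                   [0., 5., 0., 0.],
                   [1., 1., 1., 1.]])
x2 = torch.tensor([[2., 2., 2., 2.],
                   [4., 0., 1., 4.],
                   [5., 2., 2., 1.],
                   [0., 1., 4., 1.]])

# Shape (N=1, C=2, H=4, W=4); named `inp` to avoid shadowing the builtin `input`
inp = torch.stack((x1, x2), 0).unsqueeze(0)
N, C, H, W = inp.shape

pool = nn.MaxPool2d(2, 2, return_indices=True)
o, i = pool(inp)

# Decode each per-plane flat index into (row, col): index = row * W + col
rows = torch.div(i, W, rounding_mode="floor")
cols = i % W

# Option 1: shift every plane's indices by that plane's start position in the
# fully flattened tensor, so that index_select on inp.reshape(-1) is valid.
plane_offsets = torch.arange(N * C).view(N, C, 1, 1) * (H * W)
global_idx = i + plane_offsets
picked = torch.index_select(inp.reshape(-1), 0, global_idx.reshape(-1)).view_as(o)

# Option 2: index directly with the decoded (row, col) pairs.
n_idx = torch.arange(N).view(N, 1, 1, 1)
c_idx = torch.arange(C).view(1, C, 1, 1)
picked2 = inp[n_idx, c_idx, rows, cols]

print(torch.equal(picked, o), torch.equal(picked2, o))
```

Both options should reproduce the max-pool output `o`. The `gather` version above does the same thing without the offsets, because it already indexes each `(N, C)` plane separately.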