Using max-pooling indices for tensor subsampling/pooling in PyTorch

Given two tensors of the same size, how can we use the indices obtained from max-pooling one tensor to subsample or pool the other tensor in PyTorch? When I attempt this approach,

import torch

# Assuming y1 and e1 have the same dimensions
y1 = torch.randn(1, 32, 28, 28)
e1 = torch.randn(1, 32, 28, 28)

# Max-pooling
mxp = torch.nn.MaxPool2d(2, stride=2, return_indices=True)
y1_pooled, idx = mxp(y1)

# Attempt to subsample e1 using idx -- this line raises an IndexError
e1 = e1[idx]

I get the error:

Traceback (most recent call last):
    e1 = e1[idx]
IndexError: index 1 is out of bounds for dimension 0 with size 1

The direct indexing fails because e1[idx] treats the values in idx as indices along dimension 0 of e1, which has size 1. Instead, you can gather along the flattened spatial dimensions. Here’s a simple snippet illustrating how you could recover the same values from the original input and the indices, which should also be applicable to sampling e1 in your case:

>>> y1_pooled2 = torch.gather(torch.flatten(y1, -2), 2, torch.flatten(idx, -2)).reshape(1, 32, 14, 14)
>>> torch.allclose(y1_pooled, y1_pooled2)
True
>>>

You can think of each index as pointing into the flattened spatial dimensions of the input, at the position of the max value in each pooling window.
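For example, the same gather call can be applied to e1 with the indices returned from pooling y1 (a sketch using the variables from your snippet):

>>> e1_pooled = torch.gather(torch.flatten(e1, -2), 2, torch.flatten(idx, -2)).reshape(1, 32, 14, 14)
>>> e1_pooled.shape
torch.Size([1, 32, 14, 14])

Each value in e1_pooled is the element of e1 taken from the same spatial location where y1 attained its window maximum.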
