Using max_pooling indices for tensor subsampling/pooling in PyTorch

Here’s a short snippet showing how to recover the pooled values from the original input and the indices returned by max pooling. The same approach should also work for sampling e1 in your case:

>>> y1_pooled2 = torch.gather(torch.flatten(y1, -2), 2, torch.flatten(idx, -2)).reshape(1, 32, 14, 14)
>>> torch.allclose(y1_pooled, y1_pooled2)
True
>>>
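For a self-contained version, here is a minimal sketch. It assumes y1 has shape (1, 32, 28, 28) and is pooled with a 2x2 window and stride 2, which gives the (1, 32, 14, 14) shape used above. The helper name gather_by_pool_indices is hypothetical, and e1 is just a random stand-in with the same shape as y1:

```python
import torch
import torch.nn.functional as F

def gather_by_pool_indices(x, idx):
    """Hypothetical helper: pick values from x (N, C, H, W) at the flat
    H*W positions stored in idx, where idx comes from max pooling with
    return_indices=True (one offset per output cell, per (N, C) slice)."""
    flat_x = torch.flatten(x, -2)      # (N, C, H*W)
    flat_idx = torch.flatten(idx, -2)  # (N, C, H_out*W_out)
    return torch.gather(flat_x, -1, flat_idx).reshape_as(idx)

torch.manual_seed(0)
# Assumed shapes: 2x2 pooling with stride 2 on a 28x28 map -> 14x14
y1 = torch.randn(1, 32, 28, 28)
y1_pooled, idx = F.max_pool2d(y1, kernel_size=2, stride=2, return_indices=True)

y1_pooled2 = gather_by_pool_indices(y1, idx)
print(torch.allclose(y1_pooled, y1_pooled2))  # True

# e1 is a stand-in for the second tensor you want to sample at the same positions
e1 = torch.randn(1, 32, 28, 28)
e1_sampled = gather_by_pool_indices(e1, idx)
print(e1_sampled.shape)  # torch.Size([1, 32, 14, 14])
```

Using reshape_as(idx) instead of a hard-coded shape lets the helper work for any batch size, channel count, or pooling configuration, as long as x has the same spatial size as the tensor the indices were computed on.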

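You can think of each index as a position in the flattened H×W spatial plane of the input (separately for each batch element and channel), pointing at the max value of its pooling window. That is why flattening the last two dimensions of both the input and the indices and gathering along that flattened dimension recovers the pooled values.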
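To make the index layout concrete, here is a small illustrative sketch (not from the original answer). It uses a 4x4 input whose values equal their own flat positions, so the pooled values and the returned indices coincide, and then unflattens each index into a (row, col) pair with row = idx // W and col = idx % W:

```python
import torch
import torch.nn.functional as F

# 4x4 input whose values equal their flat position 0..15
x = torch.arange(16.0).reshape(1, 1, 4, 4)
vals, idx = F.max_pool2d(x, kernel_size=2, return_indices=True)
# idx holds [[5, 7], [13, 15]]: the bottom-right cell of each 2x2 window

W = x.shape[-1]
rows = torch.div(idx, W, rounding_mode="floor")  # row in the input
cols = idx % W                                   # column in the input

# Indexing the input at (row, col) gives back the pooled values
recovered = x[0, 0, rows[0, 0], cols[0, 0]]
print(torch.equal(recovered, vals[0, 0]))  # True
```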
