Pooling using indices from another max pooling

Inspired by @ptrblck’s answer in the “MaxPool2d indexing order” thread:

import torch, torch.nn as nn

def retrieve_elements_from_indices(tensor, indices, start_dim=2):
    """Gather elements of ``tensor`` at the positions given by pooling ``indices``.

    ``indices`` is expected to be the second output of a max-pooling layer
    created with ``return_indices=True`` (e.g. ``nn.MaxPool2d``). Those indices
    address each channel's spatial plane as if it were flattened. Both tensors
    are therefore flattened from ``start_dim`` onward, and the values are picked
    with ``gather`` along that single flattened dimension.

    Args:
        tensor: Tensor to read values from. Its dims before ``start_dim`` must
            equal those of ``indices``. Its spatial size should match the input
            that produced ``indices`` (not checked here).
        indices: Integer (int64) tensor of flat spatial positions.
        start_dim: First spatial dimension. The default ``2`` fits batched
            ``(N, C, ...)`` inputs; use ``1`` for unbatched ``(C, ...)`` inputs.

    Returns:
        A tensor shaped like ``indices`` holding ``tensor``'s values at those
        positions.

    Raises:
        ValueError: If the leading (non-spatial) dimensions of ``tensor`` and
            ``indices`` differ. Without this check, ``gather`` would silently
            read only a subset of ``tensor`` when ``indices`` is smaller.
    """
    lead_tensor = tuple(tensor.shape[:start_dim])
    lead_indices = tuple(indices.shape[:start_dim])
    if lead_tensor != lead_indices:
        raise ValueError(
            f"leading dims must match: tensor {lead_tensor} vs indices {lead_indices}"
        )
    # Collapse the spatial dims so the flat pooling indices address them directly.
    flat_tensor = tensor.flatten(start_dim=start_dim)
    flat_indices = indices.flatten(start_dim=start_dim)
    # gather returns a fresh contiguous tensor, so view_as is safe here.
    return flat_tensor.gather(dim=start_dim, index=flat_indices).view_as(indices)


# Demo: max-pool data1, then read data2 at the very positions data1's pooling chose.
maxpool = nn.MaxPool2d(kernel_size=2, stride=2, return_indices=True)
data1 = torch.randn(1, 2, 4, 4)  # batch of 1, 2 channels, 4x4 spatial
data2 = torch.randn(1, 2, 4, 4)  # same shape as data1

# Pool data1; with return_indices=True the layer also returns where each max came from.
output1, indices = maxpool(data1)

# Pick data2's values at data1's argmax locations (not data2's own maxima).
output2 = retrieve_elements_from_indices(data2, indices)
9 Likes