Extract values from a tensor using max-pooling indices

(Philip Meier) #1

Suppose I have two tensors x and y of the same size BxCxHxW. I want to extract the values of x at the positions selected by max-pooling y. The built-in max_pool2d only returns spatial indices, which are flat positions within each HxW plane. They therefore have to be converted into indices into the whole flattened tensor before they can be used with torch.take().
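A minimal sketch of why the conversion is needed, using a made-up 1x2x2x2 tensor and kernel_size=2 (both chosen only for illustration). The returned indices restart at 0 in every channel, so passing them straight to torch.take reads the second channel's maximum from the first plane.

```python
import torch
import torch.nn.functional as F

# Illustrative input: batch of 1, 2 channels, each a 2x2 plane.
y = torch.tensor([[[[1.0, 4.0],
                    [3.0, 2.0]],
                   [[5.0, 0.0],
                    [6.0, 7.0]]]])

# A 2x2 kernel covers the whole plane, so each channel yields one value.
values, spatidcs = F.max_pool2d(y, kernel_size=2, return_indices=True)

print(values.flatten().tolist())    # [4.0, 7.0]
print(spatidcs.flatten().tolist())  # [1, 3]  (positions within each plane)

# torch.take treats y as 1-D, so index 3 lands in channel 0, not channel 1.
print(torch.take(y, spatidcs).flatten().tolist())  # [4.0, 2.0]
```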

import torch
import torch.nn.functional as F

# "*" stands for the remaining pooling arguments (kernel_size, stride, ...)
_, spatidcs = F.max_pool2d(y, *, return_indices=True)
z = torch.take(x, spatidcs_to_idcs(y, spatidcs))

I came up with the following implementation for the conversion:

def spatidcs_to_idcs(input, spatidcs):
    # Convert per-plane indices from max_pool2d into indices into the
    # flattened B x C x H x W tensor by adding the start of each (b, c) plane.
    batch_size, num_channels, height, width = input.size()
    batch_offset = torch.arange(0, batch_size, 1,
                                dtype=spatidcs.dtype, device=input.device)
    channel_offset = torch.arange(0, num_channels, 1,
                                  dtype=spatidcs.dtype, device=input.device)
    # Plane (b, c) starts at flat index (b * C + c) * H * W.
    offset = (batch_offset.view(-1, 1, 1, 1) * num_channels +
              channel_offset.view(1, -1, 1, 1)) * height * width
    return spatidcs + offset
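Since b * C + c runs through 0, 1, ..., B*C - 1 in the same order as the planes are laid out, the two arange calls can be folded into one. A minimal sketch of that idea; the name `spatidcs_to_idcs_single`, the tensor sizes and kernel_size=2 are hypothetical choices for illustration. It still creates one small tensor per call; it only drops the second arange and the extra broadcasted addition.

```python
import torch
import torch.nn.functional as F


def spatidcs_to_idcs_single(input, spatidcs):
    """Hypothetical variant of spatidcs_to_idcs that builds one offset tensor.

    Plane (b, c) of a B x C x H x W tensor starts at flat index
    (b * C + c) * H * W, and b * C + c enumerates 0 .. B*C - 1.
    """
    batch_size, num_channels, height, width = input.size()
    plane_offset = torch.arange(batch_size * num_channels,
                                dtype=spatidcs.dtype, device=spatidcs.device)
    offset = plane_offset.view(batch_size, num_channels, 1, 1) * (height * width)
    return spatidcs + offset


torch.manual_seed(0)
x = torch.randn(2, 3, 4, 4)
y = torch.randn(2, 3, 4, 4)
pooled_y, spatidcs = F.max_pool2d(y, kernel_size=2, return_indices=True)

# Values of x at the positions where y attains its pooled maxima.
z = torch.take(x, spatidcs_to_idcs_single(y, spatidcs))
print(z.shape)  # torch.Size([2, 3, 2, 2])

# Sanity check: taking from y itself must reproduce the pooled values.
print(torch.equal(torch.take(y, spatidcs_to_idcs_single(y, spatidcs)), pooled_y))  # True
```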

This works correctly, but I suspect it is inefficient, since it creates two new tensors on every call. Does anyone know a better solution?
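One alternative, a different technique swapped in here rather than taken from the post, avoids the offset arithmetic entirely: because the pooling indices are already positions within each flattened HxW plane, one can flatten the spatial dimensions and use torch.gather along that dimension. A minimal sketch, assuming the same BxCxHxW layout and kernel_size=2; `take_by_pool_indices` is a hypothetical helper name.

```python
import torch
import torch.nn.functional as F


def take_by_pool_indices(x, spatidcs):
    """Hypothetical helper: pick values of x at max-pooling indices.

    spatidcs holds flat positions within each H*W plane, as returned by
    F.max_pool2d(..., return_indices=True), so gathering along a flattened
    spatial dimension needs no batch/channel offsets.
    """
    x_flat = x.flatten(start_dim=2)           # B x C x (H*W)
    idx_flat = spatidcs.flatten(start_dim=2)  # B x C x (H'*W')
    return x_flat.gather(2, idx_flat).view_as(spatidcs)


torch.manual_seed(0)
x = torch.randn(2, 3, 4, 4)
y = torch.randn(2, 3, 4, 4)
pooled_y, spatidcs = F.max_pool2d(y, kernel_size=2, return_indices=True)

z = take_by_pool_indices(x, spatidcs)
print(z.shape)  # torch.Size([2, 3, 2, 2])

# Sanity check: gathering from y itself reproduces the pooled values.
print(torch.equal(take_by_pool_indices(y, spatidcs), pooled_y))  # True
```

Here no per-call offset tensor is built; flatten returns views for contiguous inputs, and only the gathered result is allocated.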