Suppose I have two tensors `x` and `y` of the same size `BxCxHxW`. I want to extract the values of `x` that are picked by a max-pooling from `y`. Since the builtin `max_pool2d` only returns the spatial indices, they have to be converted before they can be used within `take()`.
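To see why a conversion is needed, here is a small sketch showing what `max_pool2d` actually returns. The toy sizes and `kernel_size=2` are assumptions chosen only for illustration. The indices are positions inside each `HxW` plane, not positions in the whole tensor, so every channel reuses the same index range.

```python
import torch
import torch.nn.functional as F

# Toy input: B=1, C=2, H=W=4; plane c holds the values 16*c .. 16*c + 15
y = torch.arange(32, dtype=torch.float32).view(1, 2, 4, 4)

# kernel_size=2 is an example value, not taken from the question
pooled, indices = F.max_pool2d(y, kernel_size=2, return_indices=True)

# Each index is a position inside its own HxW plane, so both channels
# report the same indices [[5, 7], [13, 15]] even though channel 1's
# maxima live at flat positions 21, 23, 29, 31 of the whole tensor.
print(indices[0, 0].tolist())  # [[5, 7], [13, 15]]
print(indices[0, 1].tolist())  # [[5, 7], [13, 15]]
```

Passing these indices straight to `torch.take(y, indices)` would therefore read channel 0 twice instead of picking channel 1's maxima.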
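```python
import torch
import torch.nn.functional as F

# `*` stands for the remaining pooling arguments (kernel_size, stride, ...)
_, spatidcs = F.max_pool2d(y, *, return_indices=True)
z = torch.take(x, spatidcs_to_idcs(y, spatidcs))
```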
I came up with the following implementation for the conversion:
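```python
def spatidcs_to_idcs(input, spatidcs):
    # Convert per-plane spatial indices (as returned by max_pool2d) into
    # flat indices over the whole BxCxHxW tensor, usable with torch.take()
    batch_size, num_channels, height, width = input.size()
    batch_offset = torch.arange(0, batch_size, 1,
                                dtype=spatidcs.dtype, device=input.device)
    channel_offset = torch.arange(0, num_channels, 1,
                                  dtype=spatidcs.dtype, device=input.device)
    # Plane (b, c) starts at flat position (b * C + c) * H * W
    offset = (batch_offset.view(-1, 1, 1, 1) * num_channels +
              channel_offset.view(1, -1, 1, 1)) * height * width
    return spatidcs + offset
```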
This works correctly, but I suspect it is inefficient, since every call allocates two new tensors (the batch and channel offsets). Does anyone know a better solution?
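Two possible directions, sketched under stated assumptions rather than offered as a definitive answer. The first (hypothetical name `spatidcs_to_idcs_single`) keeps the `torch.take()` approach but builds the offsets from a single `arange` over all `B*C` planes. The second (hypothetical name `take_by_pool_indices`) swaps in a different technique, `torch.gather` along the flattened `H*W` axis, which consumes the per-plane indices directly so no offset tensor is needed at all. The tensor sizes and `kernel_size=2` are arbitrary example values. Whether either is measurably faster than the original depends on tensor sizes and device, so it is worth benchmarking.

```python
import torch
import torch.nn.functional as F

def spatidcs_to_idcs_single(input, spatidcs):
    # Hypothetical variant: one arange over all B*C planes instead of two
    batch_size, num_channels, height, width = input.size()
    plane = torch.arange(batch_size * num_channels,
                         dtype=spatidcs.dtype, device=spatidcs.device)
    # Plane (b, c) has linear number b * C + c and starts at that number * H * W
    return spatidcs + plane.view(batch_size, num_channels, 1, 1) * (height * width)

def take_by_pool_indices(x, spatidcs):
    # Hypothetical helper: gather along the flattened H*W axis, so the
    # per-plane indices from max_pool2d are used without any offsets
    flat_x = x.flatten(start_dim=2)           # (B, C, H*W)
    flat_idx = spatidcs.flatten(start_dim=2)  # (B, C, H_out*W_out)
    return flat_x.gather(2, flat_idx).view_as(spatidcs)

torch.manual_seed(0)
x = torch.randn(2, 3, 6, 6)
y = torch.randn(2, 3, 6, 6)
# kernel_size=2 is an arbitrary example value
pooled, spatidcs = F.max_pool2d(y, kernel_size=2, return_indices=True)

z_take = torch.take(x, spatidcs_to_idcs_single(y, spatidcs))
z_gather = take_by_pool_indices(x, spatidcs)
print(torch.equal(z_take, z_gather))  # True
```

Both only index into `x`, so they select exactly the same elements. Applying either to `y` itself reproduces the pooled output, which is a quick way to sanity-check any conversion.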