I want to use the indices returned by max_pool2d_with_indices(x) to extract elements from a tensor m with the same dimensions as x. However, the returned values are wrong. Here is a minimal working example:
import torch


def main():
    x = [[[[0, 0, 0, 0],
           [0, 1, 1, 0],
           [0, 1, 1, 0],
           [0, 0, 0, 0]],
          [[1, 0, 0, 1],
           [0, 0, 0, 0],
           [0, 0, 0, 0],
           [1, 0, 0, 1]]]]
    x = torch.tensor(x).float()
    m = [[[[10, 11, 12, 13],
           [14, 15, 16, 17],
           [18, 19, 20, 21],
           [22, 23, 24, 25]],
          [[26, 27, 28, 29],
           [30, 31, 32, 33],
           [34, 35, 36, 37],
           [38, 39, 40, 41]]]]
    m = torch.tensor(m).float()
    y, idx = torch.nn.functional.max_pool2d_with_indices(x, kernel_size=(2, 2))
    m_select = torch.take(m, idx)
    print(m_select)


if __name__ == "__main__":
    main()
I want m_select to be
tensor([[[[15., 16.],
          [19., 20.]],
         [[26., 29.],
          [38., 41.]]]])
but currently I get
tensor([[[[15., 16.],
          [19., 20.]],
         [[10., 13.],
          [22., 25.]]]])
In other words, I want the elements of m at the positions where the activations in x are highest within each pooling window. Currently, torch.take() returns only elements from the first feature map. How can I get the desired output?
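One possible direction, sketched under an assumption: the output above is consistent with the pooling indices being offsets into each (batch, channel) plane flattened to H*W, whereas torch.take() indexes its input as one flat 1-D tensor. If that holds, flattening only the spatial dimensions and using gather along them should select per channel. The helper name select_by_pool_indices is made up for illustration and is not part of PyTorch.

```python
import torch
import torch.nn.functional as F


def select_by_pool_indices(m, idx):
    """Hypothetical helper: pick elements of m at max-pool indices, per channel.

    Assumes idx holds offsets into each flattened (H*W) plane of m.
    """
    m_flat = m.flatten(start_dim=2)      # shape (N, C, H*W)
    idx_flat = idx.flatten(start_dim=2)  # shape (N, C, H_out*W_out)
    # gather along the flattened spatial dim, then restore the pooled shape
    return m_flat.gather(dim=2, index=idx_flat).view_as(idx)


# Same x and m as in the example above
x = torch.tensor([[[[0, 0, 0, 0],
                    [0, 1, 1, 0],
                    [0, 1, 1, 0],
                    [0, 0, 0, 0]],
                   [[1, 0, 0, 1],
                    [0, 0, 0, 0],
                    [0, 0, 0, 0],
                    [1, 0, 0, 1]]]]).float()
m = torch.arange(10, 42).float().reshape(1, 2, 4, 4)

y, idx = F.max_pool2d_with_indices(x, kernel_size=(2, 2))
m_select = select_by_pool_indices(m, idx)
print(m_select)
```

Because gather works independently along dim 2 for every (batch, channel) pair, the second feature map is read from its own plane of m rather than from the start of the whole tensor.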