MaxPool2d indexing order

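The indices returned by MaxPool2d (with return_indices=True) point into each flattened H*W plane of the input, so you could flatten the spatial dimensions of the input tensor and use gather: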

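# assumes o, i are the output and indices of MaxPool2d(..., return_indices=True) applied to input
x = torch.flatten(input, 2)  # (N, C, H, W) -> (N, C, H*W)
o2 = torch.gather(x, 2, torch.flatten(i, 2)).view(o.size())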
print((o==o2).all())
> tensor(1, dtype=torch.uint8)
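
For anyone who wants to run this end to end, here is a minimal self-contained sketch. The input shape, the kernel size, and the names inp, pool, rows, cols and o3 are assumptions made for illustration and are not from the original post. It also decodes the flat indices into (row, col) coordinates, assuming the row-major layout index = row * W + col within each plane.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Hypothetical input: batch of 2, 3 channels, 8x8 spatial size.
inp = torch.randn(2, 3, 8, 8)

# return_indices=True makes the layer return (output, indices).
pool = nn.MaxPool2d(kernel_size=2, stride=2, return_indices=True)
o, i = pool(inp)

# The indices address positions inside each flattened H*W plane,
# so flatten the spatial dims of both tensors before gathering.
x = torch.flatten(inp, 2)                     # (N, C, H*W)
o2 = torch.gather(x, 2, torch.flatten(i, 2))  # (N, C, H_out*W_out)
o2 = o2.view(o.size())

# Decode the flat indices into (row, col), assuming row-major order.
W = inp.size(3)
rows = torch.div(i, W, rounding_mode="floor")
cols = i % W
n_idx = torch.arange(inp.size(0))[:, None, None, None]
c_idx = torch.arange(inp.size(1))[None, :, None, None]
o3 = inp[n_idx, c_idx, rows, cols]

print(torch.equal(o, o2), torch.equal(o, o3))
```

If both comparisons hold, the gather route and the explicit (row, col) lookup recover the same pooled values, which is a quick way to check the indexing order on your own PyTorch version.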