Here’s a simple snippet showing how you could recover the same values from the original input and the indices. The same approach should also work for sampling e1 in your case:
>>> y1_pooled2 = torch.gather(torch.flatten(y1, -2), 2, torch.flatten(idx, -2)).reshape(1, 32, 14, 14)
>>> torch.allclose(y1_pooled, y1_pooled2)
True
You can think of each entry in idx as the position of the max value of one pooling window within the flattened spatial dimensions (H * W) of the input, computed separately for each channel.
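For completeness, here is a self-contained sketch of the same idea that you can run as-is. The input shape (1, 32, 28, 28), the 2x2 kernel, the random e1 tensor, and the helper name gather_at_pool_indices are all assumptions made for illustration; I'm guessing them from the (1, 32, 14, 14) output shape above, so adjust them to your actual setup.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)

# Assumed shapes: a 28x28 input pooled with a 2x2 window gives 14x14.
y1 = torch.randn(1, 32, 28, 28)
e1 = torch.randn(1, 32, 28, 28)  # stand-in for the tensor you want to sample

# return_indices=True also returns, for every window, the position of the
# max value inside the flattened H*W plane of each channel.
y1_pooled, idx = F.max_pool2d(y1, kernel_size=2, return_indices=True)


def gather_at_pool_indices(x, idx):
    """Hypothetical helper: pick values of x at the positions stored in idx."""
    flat_x = torch.flatten(x, -2)      # (N, C, H*W)
    flat_idx = torch.flatten(idx, -2)  # (N, C, H_out*W_out)
    return torch.gather(flat_x, 2, flat_idx).reshape(idx.shape)


# Gathering from y1 itself reproduces the pooled output.
y1_pooled2 = gather_at_pool_indices(y1, idx)
print(torch.allclose(y1_pooled, y1_pooled2))

# Gathering from e1 samples it at the locations where y1 had its max values.
e1_sampled = gather_at_pool_indices(e1, idx)
print(e1_sampled.shape)
```

Taking the output shape from idx.shape instead of hard-coding (1, 32, 14, 14) keeps the helper usable for other batch sizes and resolutions, as long as x has the same shape as the tensor that was pooled.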