Extract elements from matrix given indices from max_pool2d

In the code below, I would like to extract a submatrix from S according to the indices stored in idx.
For the example code, the resulting submatrix s would have shape (8 x 8) for each batch element.

import torch

def main():
    batch_size = 2
    channels = 2
    height = 4
    width = 4

    X = torch.arange(start=1, end=batch_size*channels*height*width+1).float()
    X = X.reshape(batch_size, channels, height, width)

    _, idx = torch.nn.functional.max_pool2d_with_indices(X, kernel_size=(2, 2))

    S = torch.arange(start=1, end=(channels*height*width)**2+1).float()
    S = S.reshape(channels*height*width, channels*height*width)

    # s = ?

if __name__ == "__main__":
    main()
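A minimal sketch of one way to fill in `# s = ?`. It assumes that `idx` indexes each channel's own flattened `height * width` plane, which is how `max_pool2d_with_indices` reports positions, and that `S` is shared across the batch. Adding a channel offset of `c * height * width` turns the per-channel indices into positions in `X[b].flatten()`. Advanced indexing then selects those rows and columns of `S`. The names `offset` and `flat` are illustrative only, and the script is written at module level for brevity. The result holds one 8 x 8 block per batch element, so its shape is (2, 8, 8):

```python
import torch
import torch.nn.functional as F

batch_size, channels, height, width = 2, 2, 4, 4
n = channels * height * width  # length of X[b] viewed as a vector

X = torch.arange(start=1, end=batch_size * n + 1).float()
X = X.reshape(batch_size, channels, height, width)
S = torch.arange(start=1, end=n ** 2 + 1).float().reshape(n, n)

# idx has shape (2, 2, 2, 2); each entry indexes its own channel's H*W plane.
_, idx = F.max_pool2d_with_indices(X, kernel_size=(2, 2))

# Shift channel c by c*H*W, then flatten: 8 indices into X[b].flatten()
# for each batch element (illustrative names, not a library API).
offset = torch.arange(channels).view(1, channels, 1, 1) * (height * width)
flat = (idx + offset).flatten(1)  # shape (2, 8)

# Rows (2, 8, 1) broadcast against columns (2, 1, 8): one 8 x 8 block per batch element.
s = S[flat.unsqueeze(-1), flat.unsqueeze(-2)]
print(s.shape)  # torch.Size([2, 8, 8])
```

Because `X` is built with `arange`, every 2 x 2 window has its maximum in the bottom-right corner, so both batch elements end up with the same indices here. With real data, each batch element gets its own block.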

Here is a simpler example that illustrates the problem better:

For the following tensor of shape (1, 2, 2, 2),

X = [[[[1, 0],
       [0, 0]],

      [[0, 0],
       [1, 0]]]]

max_pool2d_with_indices() with kernel_size=(2, 2) returns the following indices

idx = [[[[0]],
        [[2]]]]

(For extracting the submatrix from S, it would be better to treat X as a vector, so that idx becomes idx = [[[[0]],[[6]]]]. However, I don't yet know how to compute these indices.)

I want to use these indices to extract the following submatrix s

s = [[0, 6],
     [48, 54]]

from

S = [[0,  1,  2,  3,  4,  5,  6,  7],
     [8,  9, 10, 11, 12, 13, 14, 15],
     [16, 17, 18, 19, 20, 21, 22, 23],
     [24, 25, 26, 27, 28, 29, 30, 31],
     [32, 33, 34, 35, 36, 37, 38, 39],
     [40, 41, 42, 43, 44, 45, 46, 47],
     [48, 49, 50, 51, 52, 53, 54, 55],
     [56, 57, 58, 59, 60, 61, 62, 63]]
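For the small example, one possible sketch: shift each channel's index by `c * H * W` to get the vector indices `[0, 6]`, then select those rows and columns of `S`. This assumes `kernel_size=(2, 2)` as in the script above. `offset` and `flat` are illustrative names, not part of any API:

```python
import torch
import torch.nn.functional as F

# Toy input: batch 1, 2 channels, 2 x 2 spatial.
X = torch.tensor([[[[1., 0.],
                    [0., 0.]],
                   [[0., 0.],
                    [1., 0.]]]])
_, C, H, W = X.shape

# Per-channel indices into each H*W plane: [[[[0]], [[2]]]].
_, idx = F.max_pool2d_with_indices(X, kernel_size=(2, 2))

# Shift channel c by c*H*W so the indices address X[0].flatten().
offset = torch.arange(C).view(1, C, 1, 1) * (H * W)
flat = (idx + offset).flatten()  # values [0, 6]

# S from the example: an 8 x 8 matrix holding 0..63 row by row.
S = torch.arange(64).reshape(8, 8)

# Rows (2, 1) broadcast against columns (1, 2); same as S[flat][:, flat].
s = S[flat.unsqueeze(1), flat.unsqueeze(0)]
print(s.tolist())  # [[0, 6], [48, 54]]
```

With more than one batch element, `.flatten(1)` instead of `.flatten()` keeps one row of indices per batch element.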

I would be very happy if someone could give me a hint on how to do that.

This might help.