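In the code below, I would like to extract a submatrix `s` from `S` according to the indices stored in `idx`. For the example code, the resulting submatrix `s` would have shape (8 x 8), i.e. one row and one column for each of the `channels * (height // 2) * (width // 2) = 8` pooled indices of a sample.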
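```python
import torch


def main():
    batch_size = 2
    channels = 2
    height = 4
    width = 4
    # X: (batch_size, channels, height, width), filled with 1, 2, 3, ...
    X = torch.arange(start=1, end=batch_size*channels*height*width+1).float()
    X = X.reshape(batch_size, channels, height, width)
    # idx: (batch_size, channels, height // 2, width // 2); each entry indexes
    # the height*width plane of its own channel
    _, idx = torch.nn.functional.max_pool2d_with_indices(X, kernel_size=(2, 2))
    # S: (channels*height*width) x (channels*height*width)
    S = torch.arange(start=1, end=(channels*height*width)**2+1).float()
    S = S.reshape(channels*height*width, channels*height*width)
    # s = ?


if __name__ == "__main__":
    main()
```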
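A quick way to see where the (8 x 8) shape comes from is to inspect `idx` for the setup above. The sketch below repeats that setup at module level; it calls `torch.nn.functional.max_pool2d` with `return_indices=True`, which returns the same (output, indices) pair as `max_pool2d_with_indices()`. With `kernel_size=(2, 2)` and the default stride, each 4 x 4 channel plane is pooled to 2 x 2, and the indices are counted within each channel's own plane, which is why both channels report the same values.

```python
import torch
import torch.nn.functional as F

batch_size, channels, height, width = 2, 2, 4, 4

# Same input as in main(): values 1, 2, 3, ... with shape (N, C, H, W)
X = torch.arange(1, batch_size * channels * height * width + 1).float()
X = X.reshape(batch_size, channels, height, width)

# return_indices=True yields the same (output, indices) pair as max_pool2d_with_indices()
_, idx = F.max_pool2d(X, kernel_size=(2, 2), return_indices=True)

print(idx.shape)         # torch.Size([2, 2, 2, 2])
# Values grow row-major, so each 2 x 2 window's maximum is its bottom-right
# element; indices count within each channel's own 4 x 4 plane
print(idx[0].tolist())   # [[[5, 7], [13, 15]], [[5, 7], [13, 15]]]
# Pooled indices per sample, i.e. the side length of s
print(idx[0].numel())    # 8
```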
Here is a simple example that illustrates the problem better. For the following tensor `X` of shape (1, 2, 2, 2)

```python
X = [[[[1, 0],
       [0, 0]],
      [[0, 0],
       [1, 0]]]]
```

`max_pool2d_with_indices()` with `kernel_size=(2, 2)` returns the following indices

```python
idx = [[[[0]],
        [[2]]]]
```
For the extraction of the submatrix from `S`, it would be better to treat `X` as a vector, so that `idx` becomes

```python
idx = [[[[0]],
        [[6]]]]
```

These are the indices I want to use to extract the following submatrix `s`

```python
s = [[ 0,  6],
     [48, 54]]
```

from

```python
S = [[ 0,  1,  2,  3,  4,  5,  6,  7],
     [ 8,  9, 10, 11, 12, 13, 14, 15],
     [16, 17, 18, 19, 20, 21, 22, 23],
     [24, 25, 26, 27, 28, 29, 30, 31],
     [32, 33, 34, 35, 36, 37, 38, 39],
     [40, 41, 42, 43, 44, 45, 46, 47],
     [48, 49, 50, 51, 52, 53, 54, 55],
     [56, 57, 58, 59, 60, 61, 62, 63]]
```

However, at this point, I don't know how to compute these vector indices from the ones returned by `max_pool2d_with_indices()`.
I would be very happy if someone could give me a hint on how to do that.
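One possible approach, sketched under the assumption that `X` is flattened per sample in (channel, row, column) order, i.e. the order of `X[b].reshape(-1)`: channel `c` then starts at offset `c * height * width` in that vector, so adding this offset to `idx` and flattening gives the vector indices (for the simple example, `2 + 1 * 2 * 2 = 6`). `S` can then be indexed with the same index list for rows and columns via broadcast advanced indexing. The helper names `to_vector_indices` and `extract_submatrices` are hypothetical, not part of PyTorch.

```python
import torch
import torch.nn.functional as F


def to_vector_indices(idx, height, width):
    """Hypothetical helper: map per-channel max-pool indices of shape
    (N, C, H_out, W_out) to indices into each sample's flattened C*H*W vector."""
    n, c = idx.shape[:2]
    # Channel c starts at offset c * height * width in X[b].reshape(-1)
    offsets = torch.arange(c, device=idx.device).view(1, c, 1, 1) * height * width
    return (idx + offsets).reshape(n, -1)  # shape (N, C*H_out*W_out)


def extract_submatrices(S, vec_idx):
    """Hypothetical helper: out[b, i, j] = S[vec_idx[b, i], vec_idx[b, j]]."""
    # (N, K, 1) row indices broadcast against (N, 1, K) column indices -> (N, K, K)
    return S[vec_idx.unsqueeze(-1), vec_idx.unsqueeze(-2)]


# Simple example from the question: X has shape (1, 2, 2, 2)
X = torch.tensor([[[[1., 0.],
                    [0., 0.]],
                   [[0., 0.],
                    [1., 0.]]]])
_, idx = F.max_pool2d(X, kernel_size=(2, 2), return_indices=True)
print(idx.flatten().tolist())  # [0, 2]

vec_idx = to_vector_indices(idx, height=2, width=2)
print(vec_idx.tolist())        # [[0, 6]]

S = torch.arange(64).reshape(8, 8)
s = extract_submatrices(S, vec_idx)
print(s[0].tolist())           # [[0, 6], [48, 54]]
```

Applied to the code in `main()`, `to_vector_indices(idx, height, width)` would give a (2, 8) tensor and `extract_submatrices(S, ...)` a (2, 8, 8) tensor, i.e. one 8 x 8 submatrix per sample. With that particular input both samples select the same indices, so `s[0]` alone is the single (8 x 8) result.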