Hi there,
I have a tensor A of shape (n, c, h, w). Suppose the first channel of the tensor looks like this:
[[1, 2, 3, 4],
[5, 6, 7, 8],
[9, 10, 11, 12],
[13, 14, 15, 16]]
I want to reshape the tensor to shape (n, c, h*w) grid by grid, so that the elements of each grid cell end up next to each other. With a 2x2 grid, the first channel of the reshaped tensor should be:
[1, 2, 5, 6,
3, 4, 7, 8,
9, 10, 13, 14,
11, 12, 15, 16]
How can I implement this using existing functions in PyTorch?
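
For reference, here is a minimal sketch of one way this might be done with only reshape and permute, assuming h and w are divisible by the grid size. The function name reshape_by_grid and the parameters gh and gw (grid height and width) are hypothetical names chosen for illustration, not part of PyTorch. The idea is to split h into (h // gh, gh) and w into (w // gw, gw), move the two block axes in front of the two within-block axes, and then flatten.

```python
import torch

def reshape_by_grid(x, gh, gw):
    # x has shape (n, c, h, w); gh and gw are hypothetical names for the grid size.
    n, c, h, w = x.shape
    # Assumption: the grid tiles the image exactly.
    assert h % gh == 0 and w % gw == 0
    # Split rows into (block row, row within block) and
    # columns into (block column, column within block).
    x = x.reshape(n, c, h // gh, gh, w // gw, gw)
    # Reorder to (n, c, block row, block column, row within block, column within block).
    x = x.permute(0, 1, 2, 4, 3, 5)
    # permute returns a non-contiguous view, so use reshape (not view) to flatten.
    return x.reshape(n, c, h * w)

# The 4x4 example from the question, as a single image with a single channel.
A = torch.arange(1, 17).reshape(1, 1, 4, 4)
out = reshape_by_grid(A, 2, 2)
print(out[0, 0].tolist())
# [1, 2, 5, 6, 3, 4, 7, 8, 9, 10, 13, 14, 11, 12, 15, 16]
```

The final reshape copies the data, because the permuted tensor is no longer contiguous in memory.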