How to flatten height and width into one dimension, grid by grid

Hi there,

I have a tensor A of shape (n, c, h, w). Suppose the first channel of the tensor looks like this:

[[1, 2, 3, 4],
[5, 6, 7, 8],
[9, 10, 11, 12],
[13, 14, 15, 16]]

I want to reshape the tensor to the shape (n, c, h*w) grid by grid, so that the elements of each grid cell stay next to each other. If the grid is 2x2, the reshaped channel should be:

[1, 2, 5, 6,
3, 4, 7, 8,
9, 10, 13, 14,
11, 12, 15, 16]

How can I implement this using existing functions in PyTorch?

Try this:

A.view(A.size(0), A.size(1), A.size(2) * A.size(3))

Thank you for your kind reply!

However, this does not work, because it flattens the tensor row by row rather than grid by grid. In this case, the output would be:

[1, 2, 3, 4,
5, 6, 7, 8,
9, 10, 11, 12,
13, 14, 15, 16]

which is not what I want.
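
To make the difference concrete, here is a small sketch on the 4x4 example from the question. The name A comes from the question; the single-sample, single-channel shape is an assumption made purely for illustration.

```python
import torch

# Assumption for illustration: one sample, one channel, matching the 4x4 example.
A = torch.arange(1, 17).view(1, 1, 4, 4)

# A plain view flattens in row-major order, so the rows are emitted one after another.
flat = A.view(A.size(0), A.size(1), A.size(2) * A.size(3))
print(flat[0, 0].tolist())
# [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16]
```

The 2x2 cells are never grouped here, because view only reinterprets the existing memory layout and does not reorder elements.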

For a general use case using a grid you could use .fold():

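import torch

x = torch.arange(1, 17).float().view(1, 1, 4, 4)

kh, kw = 2, 2  # kernel_size
dh, dw = 2, 2  # stride
# get all image windows of size (kh, kw) and stride (dh, dw)
# resulting shape here: (1, 1, 2, 2, 2, 2) = (n, c, rows of cells, cols of cells, kh, kw)
input_windows = x.unfold(2, kh, dh).unfold(3, kw, dw)
# the unfolded view is not contiguous, so copy it before changing the shape
output = input_windows.contiguous().view(x.size())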
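
If the goal is the (n, c, h*w) shape from the original question, the same idea can be wrapped in a small helper. This is only a sketch: the function name flatten_by_grid is made up for illustration, and it assumes that h and w are divisible by the grid size and that the stride equals the grid size (non-overlapping cells).

```python
import torch

def flatten_by_grid(x, kh, kw):
    """Hypothetical helper: flatten (n, c, h, w) to (n, c, h*w), one grid cell at a time.

    Assumes h % kh == 0 and w % kw == 0, i.e. non-overlapping cells.
    """
    n, c, h, w = x.shape
    # Windows of size (kh, kw) with stride (kh, kw):
    # shape (n, c, h // kh, w // kw, kh, kw).
    windows = x.unfold(2, kh, kh).unfold(3, kw, kw)
    # reshape makes a copy here, because the unfolded view is not contiguous.
    return windows.reshape(n, c, h * w)

x = torch.arange(1, 17).view(1, 1, 4, 4)
out = flatten_by_grid(x, 2, 2)
print(out[0, 0].tolist())
# [1, 2, 5, 6, 3, 4, 7, 8, 9, 10, 13, 14, 11, 12, 15, 16]
```

The printed order matches the grid-wise layout asked for in the first post.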

Thank you sooooooo much for your help! This works for me XD

I found that the memory cost is a little high when I plug this code into my own. I guess it’s because contiguous() copies the tensor input_windows, so the data is held in memory twice. Is there any way to avoid the memory copy?

I don’t think you can avoid the copy, but you can free some memory if you assign the result back to input_windows.
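
A small sketch of where the copy happens, assuming the same 4x4 example: unfold() returns a strided view that shares storage with x, and new memory is only allocated by contiguous(). Rebinding the name drops the reference to the intermediate view. How much memory that actually releases depends on what else still references the data, so treat this as an illustration rather than a measurement.

```python
import torch

x = torch.arange(1, 17).float().view(1, 1, 4, 4)

input_windows = x.unfold(2, 2, 2).unfold(3, 2, 2)
# The unfolded tensor is a strided view: same storage as x, no copy yet.
shares_storage = input_windows.data_ptr() == x.data_ptr()
was_contiguous = input_windows.is_contiguous()

# contiguous() allocates new memory; assigning the result back to the same
# name releases the reference to the intermediate view.
input_windows = input_windows.contiguous()
copied = input_windows.data_ptr() != x.data_ptr()

print(shares_storage, was_contiguous, copied)
# True False True
```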

PS: by .fold() I obviously meant .unfold() ;)