Reshaping windows into an image

I’m building a model that does the following: (i) breaks an image up into small square windows, (ii) applies a network to each window, transforming each into an output window of the same size, and (iii) reassembles the output windows into an output image of the same size as the input image. I’m wondering if there’s an efficient way to do the last step without loops.

The stack of windows output by the network has size (batch * num_windows x num_channels x window_side x window_side) and the output image should have size (batch x num_channels x image_side x image_side).

The naive approach of using just a single .view() doesn’t work here, but I imagine that there is a combination of transposes and views that will do the job. Any tips?

Do the small windows overlap? If so, you would have to decide how to combine the overlapping regions (summing, averaging, etc.).
If not, you could try the following:

import torch
import torch.nn as nn
import torch.nn.functional as F

channels = 1
h, w = 12, 12
image = torch.randn(1, channels, h, w)

kh, kw = 3, 3 # kernel size
dh, dw = 3, 3 # stride (equal to the kernel size, so the windows don't overlap)

# cut the image into non-overlapping windows
patches = image.unfold(2, kh, dh).unfold(3, kw, dw)        # (1, channels, 4, 4, kh, kw)
patches = patches.contiguous().view(-1, channels, kh, kw)  # (16, channels, kh, kw)

# dummy "network": a 1x1 conv with its weight fixed to 1
conv = nn.Conv2d(channels, channels, 1, bias=False)
conv.weight.data.fill_(1.)
output = conv(patches)  # (16, channels, kh, kw)

# reassemble the windows; fold expects (batch, channels*kh*kw, num_windows)
tmp = output.view(1, 16, kh*kw).permute(0, 2, 1)  # 16 = number of windows
im = F.fold(tmp, (h, w), (kh, kw), 1, 0, (dh, dw))  # (1, channels, h, w)
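
As a quick sanity check: since the 1x1 convolution above has a single channel, no bias, and its weight fixed to 1, it acts as an identity, so the reassembled image should match the input exactly:

# with the all-ones 1x1 conv and non-overlapping windows,
# the folded result should reproduce the original image
print(torch.allclose(im, image))  # True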

Note that you would have to build PyTorch from source, since F.fold is not yet available in the binary release.
Let me know if this works for you.


Thank you for the pointers! This code does the trick!

Is it possible to do the same thing for a tensor with a shape of:

image = torch.randn(1, channels, D, h, w)

where D is the depth?

Yes, that should also be possible.
If you would like to create the “patches” along D, h, and w, you would need to unfold the additional D dimension as well.
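
A minimal sketch of what that could look like, assuming non-overlapping patches (stride equal to the patch size) and a single channel as in the 2D example above; the shapes and names (kd, dd, nd, restored, etc.) are just illustrative:

import torch

channels = 1
D, h, w = 8, 12, 12
image = torch.randn(1, channels, D, h, w)

kd, kh, kw = 2, 3, 3  # patch size per dimension
dd, dh, dw = 2, 3, 3  # strides (equal to the patch size, so no overlap)

# unfold depth, height, and width in turn
patches = image.unfold(2, kd, dd).unfold(3, kh, dh).unfold(4, kw, dw)
# -> (1, channels, D//dd, h//dh, w//dw, kd, kh, kw)
nd, nh, nw = patches.shape[2:5]
patches = patches.contiguous().view(-1, channels, kd, kh, kw)

# ... apply your network to each patch here ...

# F.fold only supports 2D spatial outputs, but for non-overlapping patches
# (and channels == 1, as above) the reassembly can be done with a view and a permute
restored = patches.view(1, channels, nd, nh, nw, kd, kh, kw)
restored = restored.permute(0, 1, 2, 5, 3, 6, 4, 7).contiguous()
restored = restored.view(1, channels, D, h, w)

print(torch.equal(restored, image))  # True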