Unable to reconstruct unfolded image

I’m having difficulty reconstructing an image from patches after running it through my model. The same approach worked with a single-channel input image of size (1460, 1936), but it seems to scramble the patch order with a two-channel input and a one-channel output. I’m sure I’m missing something simple, but I can’t find the error in my code.

Example code:

import numpy as np
import torch
import torch.nn.functional as F

in_channels = 2
temp = np.random.rand(in_channels, 1460, 1936)
normalized = np.zeros([in_channels,temp.shape[1],temp.shape[2]])
for i in range(in_channels):
    normalized[i,:,:] = (temp[i,:,:] - np.mean(temp[i,:,:]))/(np.max(temp[i,:,:]) - np.min(temp[i,:,:]))
x = torch.from_numpy(normalized)
# kernel size
k = 256
# stride
d = 256
# height padding
hpad = (k - x.size(1) % k) // 2
# width padding
wpad = (k - x.size(2) % k) // 2
# pad x
x = F.pad(x, (wpad, wpad, hpad, hpad))  # shape torch.Size([2, 1536, 2048])

patches = x.unfold(1, k, d).unfold(2, k, d) #shape torch.Size([2, 6, 8, 256, 256])
unfold_shape = patches.size()
# reshape to (batch, in_channels, k, k)
patches = patches.contiguous().view(-1, in_channels,k, k) #shape torch.Size([48, 2, 256, 256])
temp = torch.empty(patches.shape[0],k,k) #torch.Size([48, 256, 256])
#loop over all patches, feed model predictions back into storage tensor
for i,patch in enumerate(patches):
    temp[i,:,:] = model(patch.view(-1,in_channels,k,k).to(device, dtype = torch.float)).cpu().detach()[0][1]
# Reshape back
patches_orig = temp.view(unfold_shape[1],unfold_shape[2],unfold_shape[3],unfold_shape[4])
output_h = unfold_shape[1] * unfold_shape[3]
output_w = unfold_shape[2] * unfold_shape[4]
patches_orig = patches_orig.permute(0,2,1,3).contiguous()
patches_orig = patches_orig.view(output_h, output_w)

Your code is unfortunately not executable, so I can only make a few suggestions.
I assume this line of code is wrong:

patches = patches.contiguous().view(-1, in_channels,k, k)

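patches should have the shape [2, 6, 8, 256, 256] ([channels, patch_h, patch_w, h, w]), so the view call would interleave the data: with the channel dimension first, each group of in_channels consecutive patches comes from the same channel at neighboring locations, rather than from both channels at one location.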
Probably patches.permute(1, 2, 0, 3, 4).view(-1, in_channels, k, k) would work.
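To illustrate the interleaving, here is a minimal sketch on a toy tensor. The sizes and fill values (channel 0 all zeros, channel 1 all ones) are made up for the demo and are not taken from your code. It uses reshape rather than view so that it runs regardless of memory layout.

```python
import torch

C, k = 2, 4
# Toy image: channel 0 is all zeros, channel 1 is all ones.
x = torch.stack([torch.zeros(8, 12), torch.ones(8, 12)])

patches = x.unfold(1, k, k).unfold(2, k, k)  # [C, 2, 3, k, k]

# Plain view: consecutive patches of the SAME channel get paired up.
wrong = patches.contiguous().view(-1, C, k, k)
# Permute first: the channel dim moves next to the patch pixels.
right = patches.permute(1, 2, 0, 3, 4).reshape(-1, C, k, k)

def channels_intact(batch):
    # True if every item holds channel 0 (zeros) then channel 1 (ones).
    return all(bool((p[0] == 0).all() and (p[1] == 1).all()) for p in batch)

wrong_ok = channels_intact(wrong)
right_ok = channels_intact(right)
print(wrong_ok, right_ok)  # False True
```

Both results have the shape [6, 2, 4, 4], which is why the bug is easy to miss: only the content of each item differs.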

Thanks! Your suggestion worked with the addition of .contiguous(), which is needed because view cannot operate on the non-contiguous tensor returned by permute:

patches.permute(1, 2, 0, 3, 4).contiguous().view(-1, in_channels, k, k)
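For anyone landing here later, a rough end-to-end sketch of the round trip with this fix applied. fake_model is a hypothetical stand-in for the real network (it just averages the two input channels), chosen so the reconstruction can be checked against the padded input. n_h and n_w are names introduced here for the patch grid size.

```python
import torch
import torch.nn.functional as F

in_channels, k, d = 2, 256, 256
x = torch.rand(in_channels, 1460, 1936)

# Same symmetric padding as in the question.
hpad = (k - x.size(1) % k) // 2
wpad = (k - x.size(2) % k) // 2
x = F.pad(x, (wpad, wpad, hpad, hpad))       # [2, 1536, 2048]

patches = x.unfold(1, k, d).unfold(2, k, d)  # [2, 6, 8, 256, 256]
n_h, n_w = patches.size(1), patches.size(2)

# Move channels next to the patch pixels before flattening the patch grid.
batch = patches.permute(1, 2, 0, 3, 4).contiguous().view(-1, in_channels, k, k)

def fake_model(inp):
    # Hypothetical stand-in: [N, 2, k, k] -> [N, 1, k, k]
    return inp.mean(dim=1, keepdim=True)

out = fake_model(batch)[:, 0]                # [48, 256, 256]

# Fold the patch grid back into a single image.
recon = out.view(n_h, n_w, k, k).permute(0, 2, 1, 3).contiguous()
recon = recon.view(n_h * k, n_w * k)         # [1536, 2048]

matches = torch.allclose(recon, x.mean(dim=0))
print(recon.shape, matches)
```

This only works as written when stride equals kernel size (no overlap) and the padded size is an exact multiple of k. Overlapping patches would need averaging when folding back.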