How to split tensors with overlap and then reconstruct the original tensor?

Hi @ptrblck,

My aim is the same as the OP: unfold a large image into overlapping tiles, then fold them back together, averaging values where the tiles overlap. I’ve been testing this example with dummy image data, but the folded output doesn’t look like the overlapped sum of the patches.

import numpy as np
import torch
import torch.nn.functional as F
from skimage.data import astronaut
from torchvision import transforms

def tensor2im(input_image, imtype=np.uint8):
    """Converts a Tensor array into a numpy image array.
    Parameters:
        input_image (tensor) --  the input image tensor array
        imtype (type)        --  the desired type of the converted numpy array
    """
    if not isinstance(input_image, np.ndarray):
        if isinstance(input_image, torch.Tensor):  # get the data from a variable
            image_tensor = input_image.data
        else:
            return input_image
        image_numpy = image_tensor[0].cpu().float().numpy()  # convert it into a numpy array
        if image_numpy.shape[0] == 1:  # grayscale to RGB
            image_numpy = np.tile(image_numpy, (3, 1, 1))
        image_numpy = (image_numpy + 1) / 2.0 * 255.0  # post-processing: rescale from [-1, 1] to [0, 255]

    else:  # if it is a numpy array, do nothing
        image_numpy = input_image
    return image_numpy.astype(imtype)


tpi = transforms.ToPILImage()
tform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.5, 0.5, 0.5],
                                 std=[0.5, 0.5, 0.5])
        ])

x = tform(astronaut()).unsqueeze(0)
B, C, H, W = x.shape  # image tensors are laid out as [batch, channels, height, width]

kernel_size = 64
stride = 32
patches = x.unfold(3, kernel_size, stride).unfold(2, kernel_size, stride)
print(patches.shape) # [B, C, nb_patches_h, nb_patches_w, kernel_size, kernel_size]

# perform the operations on each patch
# ...

# reshape output to match F.fold input
patches = patches.contiguous().view(B, C, -1, kernel_size*kernel_size)
print(patches.shape) # [B, C, nb_patches_all, kernel_size*kernel_size]
patches = patches.permute(0, 1, 3, 2) 
print(patches.shape) # [B, C, kernel_size*kernel_size, nb_patches_all]
patches = patches.contiguous().view(B, C*kernel_size*kernel_size, -1)
print(patches.shape) # [B, C*prod(kernel_size), L] as expected by Fold
# https://pytorch.org/docs/stable/nn.html#torch.nn.Fold

output = F.fold(
    patches, output_size=(H, W), kernel_size=kernel_size, stride=stride)
print(output.shape) # [B, C, H, W]
# Take a look at the input
tpi(tensor2im(x).transpose(1,2,0))

# Take a look at the output
tpi(tensor2im(output).transpose(1,2,0))

The overflow artifacts are expected here, since F.fold sums the values wherever patches overlap. They are easily corrected by dividing by a mask generated by running a tensor of ones through the same unfold/fold operation:


…but what we can see is that the patches have not been reassembled as expected. We do get the same shape, but there is substantial scrambling.
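For reference, a minimal sketch of the mask division described above, run on a small random tensor rather than the astronaut image. It assumes `F.unfold` in place of the chained `Tensor.unfold` calls, so the patch layout is the one `F.fold` expects by construction. The helper name `fold_with_overlap_average` is hypothetical, and the sketch assumes the patches cover every pixel, i.e. `(H - kernel_size)` and `(W - kernel_size)` are multiples of `stride`.

```python
import torch
import torch.nn.functional as F


def fold_with_overlap_average(x, kernel_size, stride):
    """Unfold x into overlapping patches, fold back, and average the overlaps.

    Hypothetical helper for illustration only; assumes the patches tile the
    whole image so that no pixel has a coverage count of zero.
    """
    B, C, H, W = x.shape

    # F.unfold returns [B, C * kernel_size * kernel_size, L],
    # which is the input layout F.fold expects.
    patches = F.unfold(x, kernel_size=kernel_size, stride=stride)

    # ... per-patch processing would go here ...

    # F.fold sums the contributions of all patches covering a pixel.
    summed = F.fold(patches, output_size=(H, W),
                    kernel_size=kernel_size, stride=stride)

    # Pushing a tensor of ones through the same unfold/fold pair gives,
    # for every pixel, the number of patches that cover it.
    ones = torch.ones_like(x)
    mask = F.fold(F.unfold(ones, kernel_size=kernel_size, stride=stride),
                  output_size=(H, W), kernel_size=kernel_size, stride=stride)

    # Dividing the summed output by the coverage count averages the overlaps.
    return summed / mask, mask


x = torch.rand(1, 3, 128, 128)
recon, mask = fold_with_overlap_average(x, kernel_size=64, stride=32)
print(recon.shape, torch.allclose(recon, x, atol=1e-5))
```

With no per-patch processing, `recon` should match `x` up to floating-point error. One thing worth checking against the snippet further up: each `Tensor.unfold` call appends its window axis at the end of the shape, so the order of the two chained calls decides whether the last two axes come out as (kernel height, kernel width) or the reverse.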

Any help on this problem would be much appreciated. Cheers!