Hi @ptrblck,
My aim is the same as the OP: unfold a large image into overlapping tiles, then fold them back together, averaging values where there was overlap. I’ve been testing this example with dummy image data, but it doesn’t seem to be returning the overlapped sum.
from skimage.data import astronaut
def tensor2im(input_image, imtype=np.uint8):
""""Converts a Tensor array into a numpy image array.
Parameters:
input_image (tensor) -- the input image tensor array
imtype (type) -- the desired type of the converted numpy array
"""
if not isinstance(input_image, np.ndarray):
if isinstance(input_image, torch.Tensor): # get the data from a variable
image_tensor = input_image.data
else:
return input_image
image_numpy = image_tensor[0].cpu().float().numpy() # convert it into a numpy array
if image_numpy.shape[0] == 1: # grayscale to RGB
image_numpy = np.tile(image_numpy, (3, 1, 1))
image_numpy = (image_numpy + 1) / 2.0 * 255.0 # post-processing: tranpose and scaling
else: # if it is a numpy array, do nothing
image_numpy = input_image
return image_numpy.astype(imtype)
tpi = transforms.ToPILImage()
tform = transforms.Compose([
transforms.ToTensor(),
transforms.Normalize(mean=[0.5, 0.5, 0.5],
std=[0.5, 0.5, 0.5])
])
x = tform(astronaut()).unsqueeze(0)
B, C, W, H = x.shape
kernel_size = 64
stride = 32
patches = x.unfold(3, kernel_size, stride).unfold(2, kernel_size, stride)
print(patches.shape) # [B, C, nb_patches_h, nb_patches_w, kernel_size, kernel_size]
# perform the operations on each patch
# ...
# reshape output to match F.fold input
patches = patches.contiguous().view(B, C, -1, kernel_size*kernel_size)
print(patches.shape) # [B, C, nb_patches_all, kernel_size*kernel_size]
patches = patches.permute(0, 1, 3, 2)
print(patches.shape) # [B, C, kernel_size*kernel_size, nb_patches_all]
patches = patches.contiguous().view(B, C*kernel_size*kernel_size, -1)
print(patches.shape) # [B, C*prod(kernel_size), L] as expected by Fold
# https://pytorch.org/docs/stable/nn.html#torch.nn.Fold
output = F.fold(
patches, output_size=(H, W), kernel_size=kernel_size, stride=stride)
print(output.shape) # [B, C, H, W]
# Take a look at the input
tpi(tensor2im(x).transpose(1,2,0))
# Take a look at the output
tpi(tensor2im(output).transpose(1,2,0))
The overflow artifacts are expected here and are easily corrected by division with a mask generated by running a tensor of ones through the unfold/fold operation:
…but what we can see is that the patches have not been reassembled as expected. We do get the same shape, but there is substantial scrambling.
Any help on this problem would be much appreciated. Cheers!


