Creating non-overlapping patches and reconstructing the image from the patches

Hello, I am trying to use the unfold function to create patches, but I have not been able to understand how to use it for my case.

I have an image of size [284, 143, 143]; it is a 3D volumetric medical image, where 284 is the number of slices.
Now I want to extract non-overlapping 2D patches of size 128 × 128 and reconstruct the original image from them.

How can I achieve this?

This post gives an example how padding and unfold can be applied to create non-overlapping patches as well as how to reshape these patches back to the original input tensor.
Note that for your shapes you would need to pad the input, since the shapes are not multiples of the kernel size.

How can I change my padding here, since I want non-overlapping patches?

def extract_patches(img, kernel_size=128, stride=None):
    """Cut every 2D slice of ``img`` into square patches.

    The last two dimensions are zero-padded (split as evenly as possible
    between both sides) so that the sliding windows cover each axis exactly,
    then the windows are extracted with ``Tensor.unfold``.

    Args:
        img: 3D tensor; dims 1 and 2 are the ones that get padded and
            unfolded (e.g. ``(slices, H, W)``).
        kernel_size: Side length of the square patches.
        stride: Step between consecutive windows. ``None`` (default) means
            ``kernel_size``, i.e. non-overlapping patches. A smaller value
            produces overlapping patches.

    Returns:
        Tensor of shape ``(num_patches, kernel_size, kernel_size)``.
    """
    if stride is None:
        # Non-overlapping tiling requires the step to equal the window size.
        stride = kernel_size

    def _split_padding(size):
        # Smallest total padding for which the windows end exactly at the
        # border: the axis must be at least one kernel long and
        # (padded_size - kernel_size) must be a multiple of the stride.
        if size <= kernel_size:
            total = kernel_size - size
        else:
            total = -(size - kernel_size) % stride
        before = total // 2
        return before, total - before

    top, bottom = _split_padding(img.size(1))
    left, right = _split_padding(img.size(2))
    # F.pad takes (before, after) pairs starting from the LAST dimension.
    x = F.pad(img, (left, right, top, bottom))
    windows = x.unfold(1, kernel_size, stride).unfold(2, kernel_size, stride)
    return windows.reshape(-1, kernel_size, kernel_size)

But when I try to visualize a patch:
# Display one extracted patch in grayscale.
# NOTE(review): `patch` is presumably the tensor returned by
# extract_patches; it is not defined in this snippet.
plt.imshow(patch[10], cmap='gray')
plt.show()
The leftmost and top borders are not correctly padded.

[143 143 284] before padding (actual image size)
torch.Size([284, 157, 157]) after padding
torch.Size([1136, 128, 128]) patch size

I just added your input shape and replaced the kernel sizes as well as strides in my code snippet and it seems to work:

x = torch.randn(1, 284, 143, 143)
kc, kh, kw = 128, 128, 128  # kernel size
dc, dh, dw = 128, 128, 128  # stride (equal to the kernel size => no overlap)

# Pad each of the last three dims up to the next multiple of its kernel size.
# F.pad expects (before, after) pairs ordered from the LAST dim backwards,
# so the amounts for dim 3 come first, then dim 2, then dim 1.
pad_c = -x.size(1) % kc
pad_h = -x.size(2) % kh
pad_w = -x.size(3) % kw
x = F.pad(x, (pad_w // 2, pad_w - pad_w // 2,
              pad_h // 2, pad_h - pad_h // 2,
              pad_c // 2, pad_c - pad_c // 2))

# Each unfold appends one window dim, giving
# (1, n_c, n_h, n_w, kc, kh, kw).
patches = x.unfold(1, kc, dc).unfold(2, kh, dh).unfold(3, kw, dw)
unfold_shape = patches.size()
patches = patches.contiguous().view(-1, kc, kh, kw)
print(patches.shape)

# Reshape back
patches_orig = patches.view(unfold_shape)
output_c = unfold_shape[1] * unfold_shape[4]
output_h = unfold_shape[2] * unfold_shape[5]
output_w = unfold_shape[3] * unfold_shape[6]
# Interleave every block-count dim with its window dim so that each pair
# collapses back into one spatial dim.
patches_orig = patches_orig.permute(0, 1, 4, 2, 5, 3, 6).contiguous()
patches_orig = patches_orig.view(1, output_c, output_h, output_w)

# Check for equality
print((patches_orig == x[:, :output_c, :output_h, :output_w]).all())
> tensor(True)

Are you seeing any issues?

1 Like

Hi, @CS.Enthu. You could also use the grid sampler and aggregator in TorchIO if you don’t want to worry about low-level issues:

import torch
from torch.utils.data import DataLoader
import torchio as tio

# Random 4D tensor standing in for the volume from the question.
t = torch.rand(1, 284, 143, 143)
subject = tio.Subject(image=tio.ScalarImage(tensor=t))
# Sample patches of size (1, 128, 128) on a grid over the subject.
# NOTE(review): how TorchIO places patches when the image size is not a
# multiple of the patch size is not visible here - check the TorchIO docs.
sampler = tio.data.GridSampler(subject, (1, 128, 128))
len(sampler)  # 1136
# Batches of 64 patches: ceil(1136 / 64) = 18 batches.
loader = DataLoader(sampler, batch_size=64)
len(loader)  # 18
# The aggregator collects patches together with their locations.
aggregator = tio.data.GridAggregator(sampler)
for batch in loader:
    aggregator.add_batch(batch['image']['data'], batch['location'])
# Assembled tensor has the same shape as the input `t`.
output = aggregator.get_output_tensor()
output.shape  # torch.Size([1, 284, 143, 143])
1 Like

Hi,
I have created a patch extractor and combiner using the torch.nn.functional API.

The code is below.
Note that extract_patches_3d and extract_patches_3ds produce the same output; the latter is just shorter.

Also note that when combining patches that overlap, the overlapping elements will be summed.

import torch

def extract_patches_3ds(x, kernel_size, padding=0, stride=1):
    """Extract 3D patches from a 5D tensor using ``Tensor.unfold``.

    Args:
        x: Tensor of shape ``(B, C, D, H, W)``.
        kernel_size: Patch size, an int or a ``(kd, kh, kw)`` sequence.
        padding: Zero padding applied before extraction. Either
            * an int (same amount on both sides of D, H and W),
            * a ``(pad_d, pad_h, pad_w)`` sequence (symmetric per dim, the
              convention used by ``extract_patches_3d``), or
            * a 6-sequence passed straight to ``torch.nn.functional.pad``,
              i.e. ordered ``(w_before, w_after, h_before, h_after,
              d_before, d_after)``.
        stride: Step between patches, an int or a ``(sd, sh, sw)`` sequence.

    Returns:
        Tensor of shape ``(B * d_out * h_out * w_out, C, kd, kh, kw)``.

        NOTE: the result is a plain ``view`` of the contiguous
        ``(B, C, d_out, h_out, w_out, kd, kh, kw)`` tensor, so the leading
        ``(B, C, d_out, h_out, w_out)`` axes are flattened and regrouped in
        chunks of ``C``. For ``C > 1`` axis 1 therefore does not index the
        input channel. ``combine_patches_3d`` expects exactly this layout.
    """
    if isinstance(kernel_size, int):
        kernel_size = (kernel_size, kernel_size, kernel_size)
    if isinstance(padding, int):
        padding = (padding, padding, padding, padding, padding, padding)
    elif len(padding) == 3:
        # Per-dimension (D, H, W) padding: expand to F.pad's layout, which
        # lists (before, after) pairs starting from the last dimension.
        pad_d, pad_h, pad_w = padding
        padding = (pad_w, pad_w, pad_h, pad_h, pad_d, pad_d)
    if isinstance(stride, int):
        stride = (stride, stride, stride)

    channels = x.shape[1]

    x = torch.nn.functional.pad(x, padding)
    # (B, C, D, H, W)
    x = x.unfold(2, kernel_size[0], stride[0]).unfold(3, kernel_size[1], stride[1]).unfold(4, kernel_size[2], stride[2])
    # (B, C, d_dim_out, h_dim_out, w_dim_out, kernel_size[0], kernel_size[1], kernel_size[2])
    x = x.contiguous().view(-1, channels, kernel_size[0], kernel_size[1], kernel_size[2])
    # (B * d_dim_out * h_dim_out * w_dim_out, C, kernel_size[0], kernel_size[1], kernel_size[2])
    return x

def extract_patches_3d(x, kernel_size, padding=0, stride=1, dilation=1):
    """Extract 3D patches from a 5D tensor with two 2D ``F.unfold`` passes.

    The depth axis is unfolded first (treating ``H * W`` as one axis with a
    kernel of 1), then the height/width axes are unfolded together.

    Args:
        x: Tensor of shape ``(B, C, D, H, W)``. It may be non-contiguous.
        kernel_size: Patch size, an int or a ``(kd, kh, kw)`` sequence.
        padding: Symmetric zero padding, an int or ``(pad_d, pad_h, pad_w)``.
        stride: Step between patches, an int or ``(sd, sh, sw)``.
        dilation: Spacing between patch elements, an int or ``(dd, dh, dw)``.

    Returns:
        Tensor of shape ``(B * d_out * h_out * w_out, C, kd, kh, kw)``.

        NOTE: the result is a plain ``view`` of the contiguous
        ``(B, C, d_out, h_out, w_out, kd, kh, kw)`` tensor, so the leading
        ``(B, C, d_out, h_out, w_out)`` axes are flattened and regrouped in
        chunks of ``C``. For ``C > 1`` axis 1 therefore does not index the
        input channel. ``combine_patches_3d`` expects exactly this layout.
    """
    if isinstance(kernel_size, int):
        kernel_size = (kernel_size, kernel_size, kernel_size)
    if isinstance(padding, int):
        padding = (padding, padding, padding)
    if isinstance(stride, int):
        stride = (stride, stride, stride)
    if isinstance(dilation, int):
        dilation = (dilation, dilation, dilation)

    def get_dim_blocks(dim_in, dim_kernel_size, dim_padding = 0, dim_stride = 1, dim_dilation = 1):
        # Number of sliding-window positions along one axis.
        dim_out = (dim_in + 2 * dim_padding - dim_dilation * (dim_kernel_size - 1) - 1) // dim_stride + 1
        return dim_out

    channels = x.shape[1]

    d_dim_in = x.shape[2]
    h_dim_in = x.shape[3]
    w_dim_in = x.shape[4]
    d_dim_out = get_dim_blocks(d_dim_in, kernel_size[0], padding[0], stride[0], dilation[0])
    h_dim_out = get_dim_blocks(h_dim_in, kernel_size[1], padding[1], stride[1], dilation[1])
    w_dim_out = get_dim_blocks(w_dim_in, kernel_size[2], padding[2], stride[2], dilation[2])

    # (B, C, D, H, W)
    # reshape (not view) so that non-contiguous inputs are accepted; it only
    # copies when a view is impossible.
    x = x.reshape(-1, channels, d_dim_in, h_dim_in * w_dim_in)
    # (B, C, D, H * W)

    x = torch.nn.functional.unfold(x, kernel_size=(kernel_size[0], 1), padding=(padding[0], 0), stride=(stride[0], 1), dilation=(dilation[0], 1))
    # (B, C * kernel_size[0], d_dim_out * H * W)

    x = x.view(-1, channels * kernel_size[0] * d_dim_out, h_dim_in, w_dim_in)
    # (B, C * kernel_size[0] * d_dim_out, H, W)

    x = torch.nn.functional.unfold(x, kernel_size=(kernel_size[1], kernel_size[2]), padding=(padding[1], padding[2]), stride=(stride[1], stride[2]), dilation=(dilation[1], dilation[2]))
    # (B, C * kernel_size[0] * d_dim_out * kernel_size[1] * kernel_size[2], h_dim_out * w_dim_out)

    x = x.view(-1, channels, kernel_size[0], d_dim_out, kernel_size[1], kernel_size[2], h_dim_out, w_dim_out)
    # (B, C, kernel_size[0], d_dim_out, kernel_size[1], kernel_size[2], h_dim_out, w_dim_out)

    # Move the block-count dims in front of the kernel dims.
    x = x.permute(0, 1, 3, 6, 7, 2, 4, 5)
    # (B, C, d_dim_out, h_dim_out, w_dim_out, kernel_size[0], kernel_size[1], kernel_size[2])

    x = x.contiguous().view(-1, channels, kernel_size[0], kernel_size[1], kernel_size[2])
    # (B * d_dim_out * h_dim_out * w_dim_out, C, kernel_size[0], kernel_size[1], kernel_size[2])

    return x



def combine_patches_3d(x, kernel_size, output_shape, padding=0, stride=1, dilation=1):
    """Reassemble a 5D volume from patches with two 2D ``F.fold`` passes.

    Inverse of ``extract_patches_3d`` / ``extract_patches_3ds`` when called
    with the same ``kernel_size``, ``padding``, ``stride`` and ``dilation``.
    Where patches overlap, the overlapping elements are summed.

    Args:
        x: Patches of shape ``(N, C, kd, kh, kw)`` in the layout produced by
            the extract functions, i.e. a flattening of
            ``(B, C, d_blocks, h_blocks, w_blocks, kd, kh, kw)``. It may be
            non-contiguous.
        kernel_size: Patch size, an int or a ``(kd, kh, kw)`` sequence.
        output_shape: Shape ``(B, C, D, H, W)`` of the volume to rebuild;
            only the last three entries are read.
        padding: Symmetric padding used at extraction, int or 3-sequence.
        stride: Stride used at extraction, int or 3-sequence.
        dilation: Dilation used at extraction, int or 3-sequence.

    Returns:
        Tensor of shape ``(B, C, D, H, W)``.
    """
    if isinstance(kernel_size, int):
        kernel_size = (kernel_size, kernel_size, kernel_size)
    if isinstance(padding, int):
        padding = (padding, padding, padding)
    if isinstance(stride, int):
        stride = (stride, stride, stride)
    if isinstance(dilation, int):
        dilation = (dilation, dilation, dilation)

    def get_dim_blocks(dim_in, dim_kernel_size, dim_padding = 0, dim_stride = 1, dim_dilation = 1):
        # Number of sliding-window positions along one axis.
        dim_out = (dim_in + 2 * dim_padding - dim_dilation * (dim_kernel_size - 1) - 1) // dim_stride + 1
        return dim_out

    channels = x.shape[1]
    # Here "*_dim_out" is the size of the rebuilt volume and "*_dim_in" the
    # number of patches along that axis.
    d_dim_out, h_dim_out, w_dim_out = output_shape[2:]
    d_dim_in = get_dim_blocks(d_dim_out, kernel_size[0], padding[0], stride[0], dilation[0])
    h_dim_in = get_dim_blocks(h_dim_out, kernel_size[1], padding[1], stride[1], dilation[1])
    w_dim_in = get_dim_blocks(w_dim_out, kernel_size[2], padding[2], stride[2], dilation[2])

    # reshape (not view) so that non-contiguous patch tensors are accepted;
    # it only copies when a view is impossible.
    x = x.reshape(-1, channels, d_dim_in, h_dim_in, w_dim_in, kernel_size[0], kernel_size[1], kernel_size[2])
    # (B, C, d_dim_in, h_dim_in, w_dim_in, kernel_size[0], kernel_size[1], kernel_size[2])

    x = x.permute(0, 1, 5, 2, 6, 7, 3, 4)
    # (B, C, kernel_size[0], d_dim_in, kernel_size[1], kernel_size[2], h_dim_in, w_dim_in)

    x = x.contiguous().view(-1, channels * kernel_size[0] * d_dim_in * kernel_size[1] * kernel_size[2], h_dim_in * w_dim_in)
    # (B, C * kernel_size[0] * d_dim_in * kernel_size[1] * kernel_size[2], h_dim_in * w_dim_in)

    # First fold rebuilds the H and W axes.
    x = torch.nn.functional.fold(x, output_size=(h_dim_out, w_dim_out), kernel_size=(kernel_size[1], kernel_size[2]), padding=(padding[1], padding[2]), stride=(stride[1], stride[2]), dilation=(dilation[1], dilation[2]))
    # (B, C * kernel_size[0] * d_dim_in, H, W)

    x = x.view(-1, channels * kernel_size[0], d_dim_in * h_dim_out * w_dim_out)
    # (B, C * kernel_size[0], d_dim_in * H * W)

    # Second fold rebuilds the D axis, treating H * W as one axis with kernel 1.
    x = torch.nn.functional.fold(x, output_size=(d_dim_out, h_dim_out * w_dim_out), kernel_size=(kernel_size[0], 1), padding=(padding[0], 0), stride=(stride[0], 1), dilation=(dilation[0], 1))
    # (B, C, D, H * W)

    x = x.view(-1, channels, d_dim_out, h_dim_out, w_dim_out)
    # (B, C, D, H, W)

    return x

# Round-trip demo: split a small (2, 2, 2, 4, 4) volume into 2x2x2 patches
# with padding 1 and stride 2, then fold the patches back together.
volume = torch.arange(1, 129, dtype=torch.float).reshape(2, 2, 2, 4, 4)
for shown in (volume.shape, volume):
    print(shown)

# extract_patches_3d(volume, 2, padding=1, stride=2) can be used here instead.
patch_stack = extract_patches_3ds(volume, 2, padding=1, stride=2)
for shown in (patch_stack.shape, patch_stack):
    print(shown)

restored = combine_patches_3d(patch_stack, 2, (2, 2, 2, 4, 4), padding=1, stride=2)
for shown in (restored.shape, restored):
    print(shown)

# Every element must survive the round trip unchanged.
print((volume == restored).all())

Output:

torch.Size([2, 2, 2, 4, 4])
tensor([[[[[  1.,   2.,   3.,   4.],
           [  5.,   6.,   7.,   8.],
           [  9.,  10.,  11.,  12.],
           [ 13.,  14.,  15.,  16.]],

          [[ 17.,  18.,  19.,  20.],
           [ 21.,  22.,  23.,  24.],
           [ 25.,  26.,  27.,  28.],
           [ 29.,  30.,  31.,  32.]]],


         [[[ 33.,  34.,  35.,  36.],
           [ 37.,  38.,  39.,  40.],
           [ 41.,  42.,  43.,  44.],
           [ 45.,  46.,  47.,  48.]],

          [[ 49.,  50.,  51.,  52.],
           [ 53.,  54.,  55.,  56.],
           [ 57.,  58.,  59.,  60.],
           [ 61.,  62.,  63.,  64.]]]],



        [[[[ 65.,  66.,  67.,  68.],
           [ 69.,  70.,  71.,  72.],
           [ 73.,  74.,  75.,  76.],
           [ 77.,  78.,  79.,  80.]],

          [[ 81.,  82.,  83.,  84.],
           [ 85.,  86.,  87.,  88.],
           [ 89.,  90.,  91.,  92.],
           [ 93.,  94.,  95.,  96.]]],


         [[[ 97.,  98.,  99., 100.],
           [101., 102., 103., 104.],
           [105., 106., 107., 108.],
           [109., 110., 111., 112.]],

          [[113., 114., 115., 116.],
           [117., 118., 119., 120.],
           [121., 122., 123., 124.],
           [125., 126., 127., 128.]]]]])
torch.Size([36, 2, 2, 2, 2])
tensor([[[[[  0.,   0.],
           [  0.,   0.]],

          [[  0.,   0.],
           [  0.,   1.]]],


         [[[  0.,   0.],
           [  0.,   0.]],

          [[  0.,   0.],
           [  2.,   3.]]]],



        [[[[  0.,   0.],
           [  0.,   0.]],

          [[  0.,   0.],
           [  4.,   0.]]],


         [[[  0.,   0.],
           [  0.,   0.]],

          [[  0.,   5.],
           [  0.,   9.]]]],



        [[[[  0.,   0.],
           [  0.,   0.]],

          [[  6.,   7.],
           [ 10.,  11.]]],


         [[[  0.,   0.],
           [  0.,   0.]],

          [[  8.,   0.],
           [ 12.,   0.]]]],



        [[[[  0.,   0.],
           [  0.,   0.]],

          [[  0.,  13.],
           [  0.,   0.]]],


         [[[  0.,   0.],
           [  0.,   0.]],

          [[ 14.,  15.],
           [  0.,   0.]]]],



        [[[[  0.,   0.],
           [  0.,   0.]],

          [[ 16.,   0.],
           [  0.,   0.]]],


         [[[  0.,   0.],
           [  0.,  17.]],

          [[  0.,   0.],
           [  0.,   0.]]]],



        [[[[  0.,   0.],
           [ 18.,  19.]],

          [[  0.,   0.],
           [  0.,   0.]]],


         [[[  0.,   0.],
           [ 20.,   0.]],

          [[  0.,   0.],
           [  0.,   0.]]]],



        [[[[  0.,  21.],
           [  0.,  25.]],

          [[  0.,   0.],
           [  0.,   0.]]],


         [[[ 22.,  23.],
           [ 26.,  27.]],

          [[  0.,   0.],
           [  0.,   0.]]]],



        [[[[ 24.,   0.],
           [ 28.,   0.]],

          [[  0.,   0.],
           [  0.,   0.]]],


         [[[  0.,  29.],
           [  0.,   0.]],

          [[  0.,   0.],
           [  0.,   0.]]]],



        [[[[ 30.,  31.],
           [  0.,   0.]],

          [[  0.,   0.],
           [  0.,   0.]]],


         [[[ 32.,   0.],
           [  0.,   0.]],

          [[  0.,   0.],
           [  0.,   0.]]]],



        [[[[  0.,   0.],
           [  0.,   0.]],

          [[  0.,   0.],
           [  0.,  33.]]],


         [[[  0.,   0.],
           [  0.,   0.]],

          [[  0.,   0.],
           [ 34.,  35.]]]],



        [[[[  0.,   0.],
           [  0.,   0.]],

          [[  0.,   0.],
           [ 36.,   0.]]],


         [[[  0.,   0.],
           [  0.,   0.]],

          [[  0.,  37.],
           [  0.,  41.]]]],



        [[[[  0.,   0.],
           [  0.,   0.]],

          [[ 38.,  39.],
           [ 42.,  43.]]],


         [[[  0.,   0.],
           [  0.,   0.]],

          [[ 40.,   0.],
           [ 44.,   0.]]]],



        [[[[  0.,   0.],
           [  0.,   0.]],

          [[  0.,  45.],
           [  0.,   0.]]],


         [[[  0.,   0.],
           [  0.,   0.]],

          [[ 46.,  47.],
           [  0.,   0.]]]],



        [[[[  0.,   0.],
           [  0.,   0.]],

          [[ 48.,   0.],
           [  0.,   0.]]],


         [[[  0.,   0.],
           [  0.,  49.]],

          [[  0.,   0.],
           [  0.,   0.]]]],



        [[[[  0.,   0.],
           [ 50.,  51.]],

          [[  0.,   0.],
           [  0.,   0.]]],


         [[[  0.,   0.],
           [ 52.,   0.]],

          [[  0.,   0.],
           [  0.,   0.]]]],



        [[[[  0.,  53.],
           [  0.,  57.]],

          [[  0.,   0.],
           [  0.,   0.]]],


         [[[ 54.,  55.],
           [ 58.,  59.]],

          [[  0.,   0.],
           [  0.,   0.]]]],



        [[[[ 56.,   0.],
           [ 60.,   0.]],

          [[  0.,   0.],
           [  0.,   0.]]],


         [[[  0.,  61.],
           [  0.,   0.]],

          [[  0.,   0.],
           [  0.,   0.]]]],



        [[[[ 62.,  63.],
           [  0.,   0.]],

          [[  0.,   0.],
           [  0.,   0.]]],


         [[[ 64.,   0.],
           [  0.,   0.]],

          [[  0.,   0.],
           [  0.,   0.]]]],



        [[[[  0.,   0.],
           [  0.,   0.]],

          [[  0.,   0.],
           [  0.,  65.]]],


         [[[  0.,   0.],
           [  0.,   0.]],

          [[  0.,   0.],
           [ 66.,  67.]]]],



        [[[[  0.,   0.],
           [  0.,   0.]],

          [[  0.,   0.],
           [ 68.,   0.]]],


         [[[  0.,   0.],
           [  0.,   0.]],

          [[  0.,  69.],
           [  0.,  73.]]]],



        [[[[  0.,   0.],
           [  0.,   0.]],

          [[ 70.,  71.],
           [ 74.,  75.]]],


         [[[  0.,   0.],
           [  0.,   0.]],

          [[ 72.,   0.],
           [ 76.,   0.]]]],



        [[[[  0.,   0.],
           [  0.,   0.]],

          [[  0.,  77.],
           [  0.,   0.]]],


         [[[  0.,   0.],
           [  0.,   0.]],

          [[ 78.,  79.],
           [  0.,   0.]]]],



        [[[[  0.,   0.],
           [  0.,   0.]],

          [[ 80.,   0.],
           [  0.,   0.]]],


         [[[  0.,   0.],
           [  0.,  81.]],

          [[  0.,   0.],
           [  0.,   0.]]]],



        [[[[  0.,   0.],
           [ 82.,  83.]],

          [[  0.,   0.],
           [  0.,   0.]]],


         [[[  0.,   0.],
           [ 84.,   0.]],

          [[  0.,   0.],
           [  0.,   0.]]]],



        [[[[  0.,  85.],
           [  0.,  89.]],

          [[  0.,   0.],
           [  0.,   0.]]],


         [[[ 86.,  87.],
           [ 90.,  91.]],

          [[  0.,   0.],
           [  0.,   0.]]]],



        [[[[ 88.,   0.],
           [ 92.,   0.]],

          [[  0.,   0.],
           [  0.,   0.]]],


         [[[  0.,  93.],
           [  0.,   0.]],

          [[  0.,   0.],
           [  0.,   0.]]]],



        [[[[ 94.,  95.],
           [  0.,   0.]],

          [[  0.,   0.],
           [  0.,   0.]]],


         [[[ 96.,   0.],
           [  0.,   0.]],

          [[  0.,   0.],
           [  0.,   0.]]]],



        [[[[  0.,   0.],
           [  0.,   0.]],

          [[  0.,   0.],
           [  0.,  97.]]],


         [[[  0.,   0.],
           [  0.,   0.]],

          [[  0.,   0.],
           [ 98.,  99.]]]],



        [[[[  0.,   0.],
           [  0.,   0.]],

          [[  0.,   0.],
           [100.,   0.]]],


         [[[  0.,   0.],
           [  0.,   0.]],

          [[  0., 101.],
           [  0., 105.]]]],



        [[[[  0.,   0.],
           [  0.,   0.]],

          [[102., 103.],
           [106., 107.]]],


         [[[  0.,   0.],
           [  0.,   0.]],

          [[104.,   0.],
           [108.,   0.]]]],



        [[[[  0.,   0.],
           [  0.,   0.]],

          [[  0., 109.],
           [  0.,   0.]]],


         [[[  0.,   0.],
           [  0.,   0.]],

          [[110., 111.],
           [  0.,   0.]]]],



        [[[[  0.,   0.],
           [  0.,   0.]],

          [[112.,   0.],
           [  0.,   0.]]],


         [[[  0.,   0.],
           [  0., 113.]],

          [[  0.,   0.],
           [  0.,   0.]]]],



        [[[[  0.,   0.],
           [114., 115.]],

          [[  0.,   0.],
           [  0.,   0.]]],


         [[[  0.,   0.],
           [116.,   0.]],

          [[  0.,   0.],
           [  0.,   0.]]]],



        [[[[  0., 117.],
           [  0., 121.]],

          [[  0.,   0.],
           [  0.,   0.]]],


         [[[118., 119.],
           [122., 123.]],

          [[  0.,   0.],
           [  0.,   0.]]]],



        [[[[120.,   0.],
           [124.,   0.]],

          [[  0.,   0.],
           [  0.,   0.]]],


         [[[  0., 125.],
           [  0.,   0.]],

          [[  0.,   0.],
           [  0.,   0.]]]],



        [[[[126., 127.],
           [  0.,   0.]],

          [[  0.,   0.],
           [  0.,   0.]]],


         [[[128.,   0.],
           [  0.,   0.]],

          [[  0.,   0.],
           [  0.,   0.]]]]])
torch.Size([2, 2, 2, 4, 4])
tensor([[[[[  1.,   2.,   3.,   4.],
           [  5.,   6.,   7.,   8.],
           [  9.,  10.,  11.,  12.],
           [ 13.,  14.,  15.,  16.]],

          [[ 17.,  18.,  19.,  20.],
           [ 21.,  22.,  23.,  24.],
           [ 25.,  26.,  27.,  28.],
           [ 29.,  30.,  31.,  32.]]],


         [[[ 33.,  34.,  35.,  36.],
           [ 37.,  38.,  39.,  40.],
           [ 41.,  42.,  43.,  44.],
           [ 45.,  46.,  47.,  48.]],

          [[ 49.,  50.,  51.,  52.],
           [ 53.,  54.,  55.,  56.],
           [ 57.,  58.,  59.,  60.],
           [ 61.,  62.,  63.,  64.]]]],



        [[[[ 65.,  66.,  67.,  68.],
           [ 69.,  70.,  71.,  72.],
           [ 73.,  74.,  75.,  76.],
           [ 77.,  78.,  79.,  80.]],

          [[ 81.,  82.,  83.,  84.],
           [ 85.,  86.,  87.,  88.],
           [ 89.,  90.,  91.,  92.],
           [ 93.,  94.,  95.,  96.]]],


         [[[ 97.,  98.,  99., 100.],
           [101., 102., 103., 104.],
           [105., 106., 107., 108.],
           [109., 110., 111., 112.]],

          [[113., 114., 115., 116.],
           [117., 118., 119., 120.],
           [121., 122., 123., 124.],
           [125., 126., 127., 128.]]]]])
tensor(True) 
# ignore