How to extract smaller image patches (3D)?

This should be the case and this quick test shows the behavior:

# Build a 24x24x24 volume whose values equal their flat index, so any patch
# can be compared directly against plain slicing.
side = 24
x = torch.arange(side ** 3).view(side, side, side)

kernel_size = 3
stride = 1

# Slide a window along each axis in turn; every unfold call appends one
# window dimension, giving (22, 22, 22, 3, 3, 3) in the end.
patches = x
for axis in range(3):
    patches = patches.unfold(axis, kernel_size, stride)

# Flatten each 3x3x3 window into a vector of 27 values.
windows_per_axis = (side - kernel_size) // stride + 1
patches = patches.contiguous().view(
    windows_per_axis, windows_per_axis, windows_per_axis, -1
)

# The first patch holds the same values as the top-left 3x3x3 corner of x.
print(patches[0, 0, 0])
print(x[:3, :3, :3])

The unfold op should already be using views without copies, which can be seen in the strides of patches (checked on the unfolded tensor, before the contiguous().view call):

print(patches.stride())
> (576, 24, 1, 576, 24, 1)

If any operation requires contiguous memory, the mentioned error will be raised and you would have to trigger the copy via contiguous().
So while you could create the patches using views only, you won't be able to perform all operations on these patches if the memory locations overlap; in that case you have to copy the tensor.

Hello,

I'm trying to figure out how to apply your code to overlapping 3D patches, but can't get it to work. Do you have any idea how the code needs to be changed to work for this example?

Hi all! You can use samplers in TorchIO for all this stuff. You can extract 2D, 3D or 4D patches from medical images randomly (for training) or densely (for testing). Here's a little snippet: Creating non overlapping patches and reconstructing image back from the patches

There's also support for overlapping patches.

Hi,
I have created a patch extractor and combiner using the torch.nn.functional API.

code is below,
note that extract_patches_3d and extract_patches_3ds have same output, the latter is just shorter.

Also note that when combining patches that overlap, the overlapping elements will be summed.

import torch

def _as_triple(value):
    """Return ``value`` as a 3-tuple, repeating it when it is a single int."""
    if isinstance(value, int):
        return (value, value, value)
    return tuple(value)


def _count_blocks(dim_in, dim_kernel_size, dim_padding=0, dim_stride=1, dim_dilation=1):
    """Return how many sliding-window positions fit along one axis."""
    return (dim_in + 2 * dim_padding - dim_dilation * (dim_kernel_size - 1) - 1) // dim_stride + 1


def extract_patches_3ds(x, kernel_size, padding=0, stride=1):
    """Extract sliding 3D patches from ``x`` using ``Tensor.unfold``.

    Produces the same result as ``extract_patches_3d`` (without dilation).

    Args:
        x: Tensor of shape (B, C, D, H, W).
        kernel_size: Patch size, an int or a (kd, kh, kw) sequence.
        padding: Zero padding applied before extraction. An int pads every
            side of D, H and W by that amount. A (pd, ph, pw) sequence pads
            both sides of each axis, matching ``extract_patches_3d``. A
            6-element sequence is passed unchanged to
            ``torch.nn.functional.pad``, i.e. ordered
            (w_left, w_right, h_top, h_bottom, d_front, d_back).
        stride: Window step, an int or a (sd, sh, sw) sequence.

    Returns:
        Tensor of shape (B * d_out * h_out * w_out, C, kd, kh, kw). Patches
        are ordered by batch, then depth, height and width block index, and
        ``result[i, c]`` is channel ``c`` of patch ``i``.
    """
    kernel_size = _as_triple(kernel_size)
    stride = _as_triple(stride)
    if isinstance(padding, int):
        padding = (padding,) * 6
    else:
        padding = tuple(padding)
        if len(padding) == 3:
            # F.pad lists the last dimension first, each with (before, after).
            pad_d, pad_h, pad_w = padding
            padding = (pad_w, pad_w, pad_h, pad_h, pad_d, pad_d)

    channels = x.shape[1]

    x = torch.nn.functional.pad(x, padding)
    # (B, C, D, H, W)
    for axis in range(3):
        x = x.unfold(axis + 2, kernel_size[axis], stride[axis])
    # (B, C, d_out, h_out, w_out, kd, kh, kw)

    # Move the channel axis behind the block axes; flattening without this
    # would group C neighbouring single-channel patches into one "patch".
    x = x.permute(0, 2, 3, 4, 1, 5, 6, 7)
    # (B, d_out, h_out, w_out, C, kd, kh, kw)

    return x.contiguous().view(-1, channels, *kernel_size)


def extract_patches_3d(x, kernel_size, padding=0, stride=1, dilation=1):
    """Extract sliding 3D patches from ``x`` using ``torch.nn.functional.unfold``.

    ``F.unfold`` only handles two spatial dimensions, so the depth axis is
    unfolded first (treating H * W as a single axis) and the height/width
    axes are unfolded in a second pass.

    Args:
        x: Tensor of shape (B, C, D, H, W).
        kernel_size: Patch size, an int or a (kd, kh, kw) sequence.
        padding: Zero padding added to both sides of each axis, an int or a
            (pd, ph, pw) sequence.
        stride: Window step, an int or a (sd, sh, sw) sequence.
        dilation: Spacing between patch elements, an int or a 3-sequence.

    Returns:
        Tensor of shape (B * d_out * h_out * w_out, C, kd, kh, kw). Patches
        are ordered by batch, then depth, height and width block index, and
        ``result[i, c]`` is channel ``c`` of patch ``i``.
    """
    kernel_size = _as_triple(kernel_size)
    padding = _as_triple(padding)
    stride = _as_triple(stride)
    dilation = _as_triple(dilation)

    channels = x.shape[1]
    d_dim_in, h_dim_in, w_dim_in = x.shape[2], x.shape[3], x.shape[4]
    d_dim_out = _count_blocks(d_dim_in, kernel_size[0], padding[0], stride[0], dilation[0])
    h_dim_out = _count_blocks(h_dim_in, kernel_size[1], padding[1], stride[1], dilation[1])
    w_dim_out = _count_blocks(w_dim_in, kernel_size[2], padding[2], stride[2], dilation[2])

    # (B, C, D, H, W); reshape instead of view so non-contiguous input works.
    x = x.reshape(-1, channels, d_dim_in, h_dim_in * w_dim_in)
    # (B, C, D, H * W)

    x = torch.nn.functional.unfold(
        x,
        kernel_size=(kernel_size[0], 1),
        padding=(padding[0], 0),
        stride=(stride[0], 1),
        dilation=(dilation[0], 1),
    )
    # (B, C * kd, d_out * H * W)

    x = x.view(-1, channels * kernel_size[0] * d_dim_out, h_dim_in, w_dim_in)
    # (B, C * kd * d_out, H, W)

    x = torch.nn.functional.unfold(
        x,
        kernel_size=(kernel_size[1], kernel_size[2]),
        padding=(padding[1], padding[2]),
        stride=(stride[1], stride[2]),
        dilation=(dilation[1], dilation[2]),
    )
    # (B, C * kd * d_out * kh * kw, h_out * w_out)

    x = x.view(-1, channels, kernel_size[0], d_dim_out, kernel_size[1], kernel_size[2], h_dim_out, w_dim_out)
    # (B, C, kd, d_out, kh, kw, h_out, w_out)

    # Block axes first, then channel, then the patch contents. Keeping the
    # channel axis in front of the block axes here would scramble channels
    # with neighbouring patches in the final flatten.
    x = x.permute(0, 3, 6, 7, 1, 2, 4, 5)
    # (B, d_out, h_out, w_out, C, kd, kh, kw)

    return x.contiguous().view(-1, channels, *kernel_size)


def combine_patches_3d(x, kernel_size, output_shape, padding=0, stride=1, dilation=1):
    """Reassemble patches from ``extract_patches_3d``/``extract_patches_3ds``.

    Inverse of the extract functions, built from two ``F.fold`` passes
    (height/width first, then depth). Where patches overlap, the
    overlapping elements are summed.

    Args:
        x: Tensor of shape (B * d_blocks * h_blocks * w_blocks, C, kd, kh, kw),
            ordered as returned by the extract functions.
        kernel_size: Patch size, an int or a (kd, kh, kw) sequence.
        output_shape: Shape (B, C, D, H, W) of the tensor to rebuild; only
            the last three entries are read.
        padding: Padding used at extraction, an int or a (pd, ph, pw) sequence.
        stride: Stride used at extraction, an int or a 3-sequence.
        dilation: Dilation used at extraction, an int or a 3-sequence.

    Returns:
        Tensor of shape (B, C, D, H, W).
    """
    kernel_size = _as_triple(kernel_size)
    padding = _as_triple(padding)
    stride = _as_triple(stride)
    dilation = _as_triple(dilation)

    channels = x.shape[1]
    d_dim_out, h_dim_out, w_dim_out = output_shape[2:]
    # Number of patches along each axis ("blocks").
    d_blocks = _count_blocks(d_dim_out, kernel_size[0], padding[0], stride[0], dilation[0])
    h_blocks = _count_blocks(h_dim_out, kernel_size[1], padding[1], stride[1], dilation[1])
    w_blocks = _count_blocks(w_dim_out, kernel_size[2], padding[2], stride[2], dilation[2])

    x = x.reshape(-1, d_blocks, h_blocks, w_blocks, channels, kernel_size[0], kernel_size[1], kernel_size[2])
    # (B, d_blocks, h_blocks, w_blocks, C, kd, kh, kw)

    x = x.permute(0, 4, 5, 1, 6, 7, 2, 3)
    # (B, C, kd, d_blocks, kh, kw, h_blocks, w_blocks)

    x = x.contiguous().view(
        -1,
        channels * kernel_size[0] * d_blocks * kernel_size[1] * kernel_size[2],
        h_blocks * w_blocks,
    )
    # (B, C * kd * d_blocks * kh * kw, h_blocks * w_blocks)

    x = torch.nn.functional.fold(
        x,
        output_size=(h_dim_out, w_dim_out),
        kernel_size=(kernel_size[1], kernel_size[2]),
        padding=(padding[1], padding[2]),
        stride=(stride[1], stride[2]),
        dilation=(dilation[1], dilation[2]),
    )
    # (B, C * kd * d_blocks, H, W)

    x = x.view(-1, channels * kernel_size[0], d_blocks * h_dim_out * w_dim_out)
    # (B, C * kd, d_blocks * H * W)

    x = torch.nn.functional.fold(
        x,
        output_size=(d_dim_out, h_dim_out * w_dim_out),
        kernel_size=(kernel_size[0], 1),
        padding=(padding[0], 0),
        stride=(stride[0], 1),
        dilation=(dilation[0], 1),
    )
    # (B, C, D, H * W)

    return x.view(-1, channels, d_dim_out, h_dim_out, w_dim_out)

# Round-trip check: split a small 5D tensor into 2x2x2 patches, stitch the
# patches back together and confirm the original values are recovered.
volume_shape = (2, 2, 2, 4, 4)
a = torch.arange(1, 129, dtype=torch.float).view(*volume_shape)
print(a.shape)
print(a)

# extract_patches_3d(a, 2, padding=1, stride=2) can be used here instead.
b = extract_patches_3ds(a, 2, padding=1, stride=2)
print(b.shape)
print(b)

c = combine_patches_3d(b, 2, volume_shape, padding=1, stride=2)
print(c.shape)
print(c)

# Kernel size and stride are both 2, so no element is covered twice and the
# reconstruction should match the input exactly.
print(torch.all(a == c))

Output:

torch.Size([2, 2, 2, 4, 4])
tensor([[[[[  1.,   2.,   3.,   4.],
           [  5.,   6.,   7.,   8.],
           [  9.,  10.,  11.,  12.],
           [ 13.,  14.,  15.,  16.]],

          [[ 17.,  18.,  19.,  20.],
           [ 21.,  22.,  23.,  24.],
           [ 25.,  26.,  27.,  28.],
           [ 29.,  30.,  31.,  32.]]],


         [[[ 33.,  34.,  35.,  36.],
           [ 37.,  38.,  39.,  40.],
           [ 41.,  42.,  43.,  44.],
           [ 45.,  46.,  47.,  48.]],

          [[ 49.,  50.,  51.,  52.],
           [ 53.,  54.,  55.,  56.],
           [ 57.,  58.,  59.,  60.],
           [ 61.,  62.,  63.,  64.]]]],



        [[[[ 65.,  66.,  67.,  68.],
           [ 69.,  70.,  71.,  72.],
           [ 73.,  74.,  75.,  76.],
           [ 77.,  78.,  79.,  80.]],

          [[ 81.,  82.,  83.,  84.],
           [ 85.,  86.,  87.,  88.],
           [ 89.,  90.,  91.,  92.],
           [ 93.,  94.,  95.,  96.]]],


         [[[ 97.,  98.,  99., 100.],
           [101., 102., 103., 104.],
           [105., 106., 107., 108.],
           [109., 110., 111., 112.]],

          [[113., 114., 115., 116.],
           [117., 118., 119., 120.],
           [121., 122., 123., 124.],
           [125., 126., 127., 128.]]]]])
torch.Size([36, 2, 2, 2, 2])
tensor([[[[[  0.,   0.],
           [  0.,   0.]],

          [[  0.,   0.],
           [  0.,   1.]]],


         [[[  0.,   0.],
           [  0.,   0.]],

          [[  0.,   0.],
           [  2.,   3.]]]],



        [[[[  0.,   0.],
           [  0.,   0.]],

          [[  0.,   0.],
           [  4.,   0.]]],


         [[[  0.,   0.],
           [  0.,   0.]],

          [[  0.,   5.],
           [  0.,   9.]]]],



        [[[[  0.,   0.],
           [  0.,   0.]],

          [[  6.,   7.],
           [ 10.,  11.]]],


         [[[  0.,   0.],
           [  0.,   0.]],

          [[  8.,   0.],
           [ 12.,   0.]]]],



        [[[[  0.,   0.],
           [  0.,   0.]],

          [[  0.,  13.],
           [  0.,   0.]]],


         [[[  0.,   0.],
           [  0.,   0.]],

          [[ 14.,  15.],
           [  0.,   0.]]]],



        [[[[  0.,   0.],
           [  0.,   0.]],

          [[ 16.,   0.],
           [  0.,   0.]]],


         [[[  0.,   0.],
           [  0.,  17.]],

          [[  0.,   0.],
           [  0.,   0.]]]],



        [[[[  0.,   0.],
           [ 18.,  19.]],

          [[  0.,   0.],
           [  0.,   0.]]],


         [[[  0.,   0.],
           [ 20.,   0.]],

          [[  0.,   0.],
           [  0.,   0.]]]],



        [[[[  0.,  21.],
           [  0.,  25.]],

          [[  0.,   0.],
           [  0.,   0.]]],


         [[[ 22.,  23.],
           [ 26.,  27.]],

          [[  0.,   0.],
           [  0.,   0.]]]],



        [[[[ 24.,   0.],
           [ 28.,   0.]],

          [[  0.,   0.],
           [  0.,   0.]]],


         [[[  0.,  29.],
           [  0.,   0.]],

          [[  0.,   0.],
           [  0.,   0.]]]],



        [[[[ 30.,  31.],
           [  0.,   0.]],

          [[  0.,   0.],
           [  0.,   0.]]],


         [[[ 32.,   0.],
           [  0.,   0.]],

          [[  0.,   0.],
           [  0.,   0.]]]],



        [[[[  0.,   0.],
           [  0.,   0.]],

          [[  0.,   0.],
           [  0.,  33.]]],


         [[[  0.,   0.],
           [  0.,   0.]],

          [[  0.,   0.],
           [ 34.,  35.]]]],



        [[[[  0.,   0.],
           [  0.,   0.]],

          [[  0.,   0.],
           [ 36.,   0.]]],


         [[[  0.,   0.],
           [  0.,   0.]],

          [[  0.,  37.],
           [  0.,  41.]]]],



        [[[[  0.,   0.],
           [  0.,   0.]],

          [[ 38.,  39.],
           [ 42.,  43.]]],


         [[[  0.,   0.],
           [  0.,   0.]],

          [[ 40.,   0.],
           [ 44.,   0.]]]],



        [[[[  0.,   0.],
           [  0.,   0.]],

          [[  0.,  45.],
           [  0.,   0.]]],


         [[[  0.,   0.],
           [  0.,   0.]],

          [[ 46.,  47.],
           [  0.,   0.]]]],



        [[[[  0.,   0.],
           [  0.,   0.]],

          [[ 48.,   0.],
           [  0.,   0.]]],


         [[[  0.,   0.],
           [  0.,  49.]],

          [[  0.,   0.],
           [  0.,   0.]]]],



        [[[[  0.,   0.],
           [ 50.,  51.]],

          [[  0.,   0.],
           [  0.,   0.]]],


         [[[  0.,   0.],
           [ 52.,   0.]],

          [[  0.,   0.],
           [  0.,   0.]]]],



        [[[[  0.,  53.],
           [  0.,  57.]],

          [[  0.,   0.],
           [  0.,   0.]]],


         [[[ 54.,  55.],
           [ 58.,  59.]],

          [[  0.,   0.],
           [  0.,   0.]]]],



        [[[[ 56.,   0.],
           [ 60.,   0.]],

          [[  0.,   0.],
           [  0.,   0.]]],


         [[[  0.,  61.],
           [  0.,   0.]],

          [[  0.,   0.],
           [  0.,   0.]]]],



        [[[[ 62.,  63.],
           [  0.,   0.]],

          [[  0.,   0.],
           [  0.,   0.]]],


         [[[ 64.,   0.],
           [  0.,   0.]],

          [[  0.,   0.],
           [  0.,   0.]]]],



        [[[[  0.,   0.],
           [  0.,   0.]],

          [[  0.,   0.],
           [  0.,  65.]]],


         [[[  0.,   0.],
           [  0.,   0.]],

          [[  0.,   0.],
           [ 66.,  67.]]]],



        [[[[  0.,   0.],
           [  0.,   0.]],

          [[  0.,   0.],
           [ 68.,   0.]]],


         [[[  0.,   0.],
           [  0.,   0.]],

          [[  0.,  69.],
           [  0.,  73.]]]],



        [[[[  0.,   0.],
           [  0.,   0.]],

          [[ 70.,  71.],
           [ 74.,  75.]]],


         [[[  0.,   0.],
           [  0.,   0.]],

          [[ 72.,   0.],
           [ 76.,   0.]]]],



        [[[[  0.,   0.],
           [  0.,   0.]],

          [[  0.,  77.],
           [  0.,   0.]]],


         [[[  0.,   0.],
           [  0.,   0.]],

          [[ 78.,  79.],
           [  0.,   0.]]]],



        [[[[  0.,   0.],
           [  0.,   0.]],

          [[ 80.,   0.],
           [  0.,   0.]]],


         [[[  0.,   0.],
           [  0.,  81.]],

          [[  0.,   0.],
           [  0.,   0.]]]],



        [[[[  0.,   0.],
           [ 82.,  83.]],

          [[  0.,   0.],
           [  0.,   0.]]],


         [[[  0.,   0.],
           [ 84.,   0.]],

          [[  0.,   0.],
           [  0.,   0.]]]],



        [[[[  0.,  85.],
           [  0.,  89.]],

          [[  0.,   0.],
           [  0.,   0.]]],


         [[[ 86.,  87.],
           [ 90.,  91.]],

          [[  0.,   0.],
           [  0.,   0.]]]],



        [[[[ 88.,   0.],
           [ 92.,   0.]],

          [[  0.,   0.],
           [  0.,   0.]]],


         [[[  0.,  93.],
           [  0.,   0.]],

          [[  0.,   0.],
           [  0.,   0.]]]],



        [[[[ 94.,  95.],
           [  0.,   0.]],

          [[  0.,   0.],
           [  0.,   0.]]],


         [[[ 96.,   0.],
           [  0.,   0.]],

          [[  0.,   0.],
           [  0.,   0.]]]],



        [[[[  0.,   0.],
           [  0.,   0.]],

          [[  0.,   0.],
           [  0.,  97.]]],


         [[[  0.,   0.],
           [  0.,   0.]],

          [[  0.,   0.],
           [ 98.,  99.]]]],



        [[[[  0.,   0.],
           [  0.,   0.]],

          [[  0.,   0.],
           [100.,   0.]]],


         [[[  0.,   0.],
           [  0.,   0.]],

          [[  0., 101.],
           [  0., 105.]]]],



        [[[[  0.,   0.],
           [  0.,   0.]],

          [[102., 103.],
           [106., 107.]]],


         [[[  0.,   0.],
           [  0.,   0.]],

          [[104.,   0.],
           [108.,   0.]]]],



        [[[[  0.,   0.],
           [  0.,   0.]],

          [[  0., 109.],
           [  0.,   0.]]],


         [[[  0.,   0.],
           [  0.,   0.]],

          [[110., 111.],
           [  0.,   0.]]]],



        [[[[  0.,   0.],
           [  0.,   0.]],

          [[112.,   0.],
           [  0.,   0.]]],


         [[[  0.,   0.],
           [  0., 113.]],

          [[  0.,   0.],
           [  0.,   0.]]]],



        [[[[  0.,   0.],
           [114., 115.]],

          [[  0.,   0.],
           [  0.,   0.]]],


         [[[  0.,   0.],
           [116.,   0.]],

          [[  0.,   0.],
           [  0.,   0.]]]],



        [[[[  0., 117.],
           [  0., 121.]],

          [[  0.,   0.],
           [  0.,   0.]]],


         [[[118., 119.],
           [122., 123.]],

          [[  0.,   0.],
           [  0.,   0.]]]],



        [[[[120.,   0.],
           [124.,   0.]],

          [[  0.,   0.],
           [  0.,   0.]]],


         [[[  0., 125.],
           [  0.,   0.]],

          [[  0.,   0.],
           [  0.,   0.]]]],



        [[[[126., 127.],
           [  0.,   0.]],

          [[  0.,   0.],
           [  0.,   0.]]],


         [[[128.,   0.],
           [  0.,   0.]],

          [[  0.,   0.],
           [  0.,   0.]]]]])
torch.Size([2, 2, 2, 4, 4])
tensor([[[[[  1.,   2.,   3.,   4.],
           [  5.,   6.,   7.,   8.],
           [  9.,  10.,  11.,  12.],
           [ 13.,  14.,  15.,  16.]],

          [[ 17.,  18.,  19.,  20.],
           [ 21.,  22.,  23.,  24.],
           [ 25.,  26.,  27.,  28.],
           [ 29.,  30.,  31.,  32.]]],


         [[[ 33.,  34.,  35.,  36.],
           [ 37.,  38.,  39.,  40.],
           [ 41.,  42.,  43.,  44.],
           [ 45.,  46.,  47.,  48.]],

          [[ 49.,  50.,  51.,  52.],
           [ 53.,  54.,  55.,  56.],
           [ 57.,  58.,  59.,  60.],
           [ 61.,  62.,  63.,  64.]]]],



        [[[[ 65.,  66.,  67.,  68.],
           [ 69.,  70.,  71.,  72.],
           [ 73.,  74.,  75.,  76.],
           [ 77.,  78.,  79.,  80.]],

          [[ 81.,  82.,  83.,  84.],
           [ 85.,  86.,  87.,  88.],
           [ 89.,  90.,  91.,  92.],
           [ 93.,  94.,  95.,  96.]]],


         [[[ 97.,  98.,  99., 100.],
           [101., 102., 103., 104.],
           [105., 106., 107., 108.],
           [109., 110., 111., 112.]],

          [[113., 114., 115., 116.],
           [117., 118., 119., 120.],
           [121., 122., 123., 124.],
           [125., 126., 127., 128.]]]]])
tensor(True)

I am new to CV. How do I iterate over the patches and view the image patches?

Something like this should work:

x = torch.arange(6*6).view(6, 6)
print(x)
# tensor([[ 0,  1,  2,  3,  4,  5],
#         [ 6,  7,  8,  9, 10, 11],
#         [12, 13, 14, 15, 16, 17],
#         [18, 19, 20, 21, 22, 23],
#         [24, 25, 26, 27, 28, 29],
#         [30, 31, 32, 33, 34, 35]])

kernel_size = 3
stride = 3
# Unfolding both axes yields a (2, 2, 3, 3) tensor: a 2x2 grid of 3x3 patches.
out = x.unfold(0, kernel_size, stride).unfold(1, kernel_size, stride)
print(out)
# tensor([[[[ 0,  1,  2],
#           [ 6,  7,  8],
#           [12, 13, 14]],
#
#          [[ 3,  4,  5],
#           [ 9, 10, 11],
#           [15, 16, 17]]],
#
#
#         [[[18, 19, 20],
#           [24, 25, 26],
#           [30, 31, 32]],
#
#          [[21, 22, 23],
#           [27, 28, 29],
#           [33, 34, 35]]]])

# Flatten the patch grid so the first dimension enumerates the patches.
out = out.contiguous().view(-1, kernel_size, kernel_size)
# out.shape: torch.Size([4, 3, 3])

for out_ in out:
    # visualize out_ here (e.g. plt.imshow(out_)); printing keeps the loop
    # body non-empty so the snippet runs as posted
    print(out_)

I was not able to produce an interpretable patch from an image. Here's my code after I saw your code.

# Show the resized image in the first subplot and the first 32x32 patch in
# the second, then save the figure.
img_path = './data/cat.jpg'
plt.figure(num=None, figsize=(16, 12), dpi=80)
plt.subplot(2, 4, 1)
# NOTE(review): cv2.imread returns channels in BGR order, while plt.imshow
# expects RGB, so colors will look swapped unless the image is converted.
img = cv2.imread(img_path)
img = cv2.resize(img, (224, 224))
plt.imshow(img)
# Select the second subplot; the imshow call further down draws into it.
plt.subplot(2, 4, 0 + 2)

transform = transforms.ToTensor()
# Convert the image to PyTorch tensor
whole_img = transform(img)
#size 3, 224, 224

# Non-overlapping 32x32 windows over height (dim 1) and width (dim 2).
patches = whole_img.unfold(1, 32, 32).unfold(2, 32, 32)
patches = patches.permute(1, 2, 3, 4, 0) # from [3, 7, 7, 32, 32] to [7, 7, 32, 32, 3]

patches = patches.contiguous().view(-1, 32, 32, 3)
# size 49, 32, 32, 3

plt.imshow(patches[0].detach().numpy()) # view the first patch only

plt.savefig('patch-result.jpg')

Your code works for me:

# Show the resized image, then draw all 49 patches in a 7x7 grid of subplots.
img_path = 'img.jpeg'
plt.figure(num=None, figsize=(16, 12), dpi=80)
plt.subplot(2, 4, 1)
img = cv2.imread(img_path)
# OpenCV loads images as BGR; convert to RGB so matplotlib shows true colors.
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
img = cv2.resize(img, (224, 224))
plt.imshow(img)


transform = transforms.ToTensor()
# Convert the image to PyTorch tensor
whole_img = transform(img)
#size 3, 224, 224

# Non-overlapping 32x32 windows over height (dim 1) and width (dim 2).
patches = whole_img.unfold(1, 32, 32).unfold(2, 32, 32)
patches = patches.permute(1, 2, 3, 4, 0) # from [3, 7, 7, 32, 32] to [7, 7, 32, 32, 3]

patches = patches.contiguous().view(-1, 32, 32, 3)
# size 49, 32, 32, 3

# NOTE(review): "fix" is presumably a typo for "fig"; the handle is unused.
fix, axarr = plt.subplots(7, 7)
for idx, ax in enumerate(axarr.reshape(-1)):
    ax.imshow(patches[idx].detach().numpy()) # draw patch idx in its own subplot

Outputs:
image

image

1 Like

I was printing only the white blank space (first patch) :smiley: I also noticed the lines you added.

Thanks, @ptrblck!