How to extract smaller image patches (3D)?

Is it supposed to be applied in the __getitem__ part?

Hi @Uwais_Iqbal & @ptrblck, I don't quite understand the tensor formulation in this line:

patches = img.data.unfold(0, 3, 3).unfold(1, 128, 128).unfold(2, 128, 128) 
# torch.Size([1, 40, 29, 3, 128, 128])

Can you explain how you would interpret it? Is it a batch of 40, where each image is split into 29 patches? Also, how would you train a model with a 6D tensor?

tensor.unfold creates patches given a kernel size and the stride (similar to im2col).
Most of the time you won’t need to use it manually.
It’s useful e.g. if you would like to experiment with convolution-like operations which are applied on image patches using a sliding window approach.

In this example the first unfold is basically not needed and you could achieve the same result with img.permute(1, 2, 0).
You will get 40x29 patches of shape 3x128x128, which can be further reduced (e.g. summed, matrix multiplied, etc.).
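
As a minimal sketch (assuming a random 3x5120x3712 input here, since any [3, H, W] image with H and W divisible by 128 reproduces that shape), this is what the unfold calls produce and how the patch grid could be flattened into a 4D batch for training:

import torch

img = torch.randn(3, 5120, 3712)

# dim0: all 3 channels form a single "patch" (hence the leading 1)
# dim1/dim2: 5120/128 = 40 and 3712/128 = 29 spatial patches
patches = img.unfold(0, 3, 3).unfold(1, 128, 128).unfold(2, 128, 128)
print(patches.shape)  # torch.Size([1, 40, 29, 3, 128, 128])

# flatten the patch grid into a batch of 40*29 = 1160 images
batch = patches.contiguous().view(-1, 3, 128, 128)
print(batch.shape)    # torch.Size([1160, 3, 128, 128])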


Thanks! I am trying to save the extracted patches from the code you have shared. I have tried using indexing and splitting, but can't seem to avoid the for loops. Is there a better way to do the below, avoiding the loops altogether?

!wget https://eoimages.gsfc.nasa.gov/images/imagerecords/88000/88094/niobrara_photo_lrg.jpg

import torch
import numpy as np
from PIL import Image
import matplotlib.pyplot as plt
from torchvision.utils import save_image
import torchvision.transforms as transforms
from pathlib import Path

patch_size=512
stride=patch_size
pil2tensor = transforms.ToTensor()


file=Path('/content/niobrara_photo_lrg.jpg')
filename=file.stem

rgb_image = pil2tensor(Image.open(file))
# patch grid of shape [1, n_rows, n_cols, 3, patch_size, patch_size]
patches = rgb_image.data.unfold(0, 3, 3).unfold(1, patch_size, stride).unfold(2, patch_size, stride)
a = list(patches.shape)

# split the grid into single rows (dim 1), then single columns (dim 2), and save each patch
x = patches[:,torch.from_numpy(np.arange(0,a[1])),:,:,:,:].split(1, dim=1)
for i in list(np.arange(a[1])):
  y = x[i][:,:,torch.from_numpy(np.arange(0,a[2])),:,:,:].split(1, dim=2)
  for j in list(np.arange(a[2])):
    # squeeze the singleton dims so save_image receives a [3, patch_size, patch_size] tensor
    save_image(y[j].squeeze(), filename+'-'+str(i)+'-'+str(j)+'.png')


Hi, you used the view method to fold the patches. However, that only works when the patches are non-overlapping. What should I do if the patches overlap and I want to sum the overlapping parts when folding? I looked up torch.nn.Fold, which only supports 2D batched image-like tensors of shape (N, C, H, W), but I want to fold 3D tensors of shape (N, C, Z, H, W). Are there any solutions to this problem? Thank you!

Sorry, I can't access the code website. Could you send a new link?

I’m also interested in this, were you able to find any solution for 3D tensors of shape (N, C, Z, H, W)?


Hey,
did you ever figure out whether there is a nice PyTorch solution to rebuild the 3D image from patches when the patches are overlapping?
It takes forever to rebuild manually.
Thanks in advance!

Hi @ptrblck,

I have two DICOM datasets of different modalities: CT (512 x 512 x 568) and PET (200 x 200 x 426).
Both images are preprocessed to obtain volumetric data of equal size and pixel spacing.
The dimensions after resampling are 143 x 143 x 284 for both, with a slice thickness of 3 mm.
Now I want to extract 64 x 64 x 64 patches from my preprocessed data with a sliding window, using an overlap of 18 x 18 x 18 between neighbouring patches.
My aim is to extract these patches from my volume data and pass them to the U-Net based generator of my CycleGAN.
Can you explain how I can achieve this in PyTorch?

Thank you in advance!

You could use my code snippet and add the 3rd dimension to the unfold operation in order to create the [64, 64, 64] cubes.
Would you like to use a stride of 46 to create the overlapping 18 voxels?


Thank you for your response.

Yes, I will use a stride of 46 to create the 18-voxel overlap.
I would also like to understand how to calculate the number of patches based on the patch size and stride.

You could use the shape calculations from the Conv2d docs, which correspond to the used patches.

Also, this doc about conv arithmetic might be useful for the general idea.
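
For example, for the 284-voxel dimension with 64-voxel patches and a stride of 46, a quick sketch of the calculation (no padding, no dilation) would be:

# number of sliding windows along one dimension (Conv2d shape formula):
#   n = (size - kernel_size) // stride + 1
size, kernel_size, stride = 284, 64, 46
n = (size - kernel_size) // stride + 1
print(n)  # 5 -> the last window ends at voxel 248, so the remaining 36 voxels would need padding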

Let me know if you get stuck.

I would keep the volumes as a whole and create the patches after loading the data in the Dataset.
Storing each patch might be beneficial for loading it later, but I assume keeping the correspondence between the patches might be quite tricky in the long run.

For a patch of 64x64 and a stride of 46, this code should work:

import torch
import torch.nn.functional as F

x = torch.randn(143, 143, 284)

kernel_size = 64
stride = 46

# Calculate the padding needed so the sliding windows cover the whole volume
pad0_left = (x.size(0) // stride * stride + kernel_size) - x.size(0)
pad1_left = (x.size(1) // stride * stride + kernel_size) - x.size(1)
pad2_left = (x.size(2) // stride * stride + kernel_size) - x.size(2)

# Split the padding (roughly) symmetrically between both sides of each dimension
pad0_right = pad0_left // 2 if pad0_left % 2 == 0 else pad0_left // 2 + 1
pad1_right = pad1_left // 2 if pad1_left % 2 == 0 else pad1_left // 2 + 1
pad2_right = pad2_left // 2 if pad2_left % 2 == 0 else pad2_left // 2 + 1

pad0_left = pad0_left // 2
pad1_left = pad1_left // 2
pad2_left = pad2_left // 2

x = F.pad(x, (pad2_left, pad2_right, pad1_left, pad1_right, pad0_left, pad0_right))
print(x.shape)  # torch.Size([202, 202, 340])

# Create the overlapping [64, 64, 64] cubes
ret = x.unfold(0, kernel_size, stride).unfold(1, kernel_size, stride).unfold(2, kernel_size, stride)
print(ret.shape)  # torch.Size([4, 4, 7, 64, 64, 64])

I would recommend verifying it by visualizing the patches, as I have only checked it on random input.
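
If it helps, the resulting patch grid could then be flattened into a batch of single-channel volumes before passing it to the 3D U-Net (a rough sketch, continuing from the code above):

# ret: [4, 4, 7, 64, 64, 64] -> [112, 1, 64, 64, 64], i.e. (N, C, D, H, W)
patches = ret.contiguous().view(-1, kernel_size, kernel_size, kernel_size)
patches = patches.unsqueeze(1)
print(patches.shape)  # torch.Size([112, 1, 64, 64, 64])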


Hi @ptrblck, thanks for your inspiring code. Could you please tell me how we can fold the patches back when the patches overlap? Thanks.

For overlapping patches I would recommend to use nn.Fold (after unfolding via nn.Unfold), as it’ll take care of summing the overlapping values. You could probably implement this behavior manually with some permutations and view operations, but nn.Fold already provides this functionality.
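
As a minimal 2D sketch of that round trip (with made-up sizes): fold sums the overlapping values, and dividing by the fold of a ones tensor recovers the input by averaging:

import torch
import torch.nn as nn

x = torch.randn(1, 3, 8, 8)
unfold = nn.Unfold(kernel_size=4, stride=2)
fold = nn.Fold(output_size=(8, 8), kernel_size=4, stride=2)

cols = unfold(x)  # [1, 3*4*4, 9] flattened overlapping patches
rec = fold(cols)  # overlapping values are summed

# count how often each pixel was covered to undo the summation
counts = fold(unfold(torch.ones_like(x)))
print(torch.allclose(rec / counts, x))  # True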

I ran into a similar problem, but one that I would like to avoid solving with Unfold (I'll explain why).

I have an image, which can be rather large, let’s say 512x512. From this image, I’d like to extract smaller patches at random offsets (middle points). I don’t expect to need anywhere close to all the patches, let’s say I want 5% of all possible patches, so Unfold is wasteful both in terms of memory and computation required (to remove unwanted offsets at the end).

The best solution I've been able to think of is manually looping over the midpoints and assembling the slices into an empty (N, C, Px, Py) tensor; however, this isn't very fast (and I believe it creates one autograd node per slice).

Is there a way this can be vectorized, assuming the middle points can be completely random and overlapping?

Thanks for your quick reply. However, nn.Fold only seems to work on 2D inputs, which doesn't cover the 3D case. Is there any other way to fold the 3D patches back into the 3D input?

You might take a look at @KFrank’s inclusive sum suggestion. Depending on your actual operation, this might work.

@Katou2 Yeah, you are right and I’m not sure what the best approach would be.
Maybe a custom extension with a sliding window approach would be the fastest way.

I wrote Python code that folds the extracted patches in a sliding-window manner and compared it with nn.Fold; it is much slower. Does that mean we need to write a C++ extension for PyTorch?