How to extract smaller image patches (3D)?

Best way to extract smaller image patches (3D)?
First, I would like to read 10 three-dimensional volumes of size (H, W, S) and then downsample them to (H/2, W/2, S/2).
Second, I want to use a sliding window to extract patches of size (64, 64, 64) from the downsampled volumes.
Are there any examples of this kind of processing in a DataLoader?


Maybe you can draw some ideas from here https://github.com/pytorch/pytorch/issues/3387

Hi Naruto, thank you for sharing. But I did not find anything there that addresses my problem. Do you have any other ideas?

For the second use case you could use Tensor.unfold:

S = 128 # slices (depth dim)
W = 256 # width
H = 256 # height
batch_size = 10

x = torch.randn(batch_size, S, W, H)

size = 64 # patch size
stride = 64 # patch stride
patches = x.unfold(1, size, stride).unfold(2, size, stride).unfold(3, size, stride)
print(patches.shape)
> torch.Size([10, 2, 4, 4, 64, 64, 64])

patches now contains [2, 4, 4] patches of size [64, 64, 64] for each volume.

For the first use case:
You could use pooling operators like nn.MaxPool3d and nn.AvgPool3d.

Here is an example using the functional API:

import torch.nn.functional as F

# downsample by a factor of 2 in every dimension
x_half = F.max_pool3d(x, kernel_size=2, stride=2)
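
To connect this with the DataLoader question: here is a rough sketch of a Dataset that applies both steps (downsampling and patch extraction) up front. It assumes each volume is already loaded as a (S, W, H) tensor whose downsampled dimensions are at least 64; the class name and the volumes argument are just placeholders:

import torch
import torch.nn.functional as F
from torch.utils.data import Dataset

class VolumePatchDataset(Dataset):
    # hypothetical Dataset; `volumes` is assumed to be a list of (S, W, H) tensors
    def __init__(self, volumes, size=64, stride=64):
        patches = []
        for vol in volumes:
            # step 1: downsample by a factor of 2 in every dimension
            vol = F.max_pool3d(vol[None, None], kernel_size=2, stride=2)[0, 0]
            # step 2: sliding-window patches of shape (size, size, size)
            p = vol.unfold(0, size, stride).unfold(1, size, stride).unfold(2, size, stride)
            patches.append(p.reshape(-1, size, size, size))
        self.patches = torch.cat(patches, dim=0)

    def __len__(self):
        return self.patches.size(0)

    def __getitem__(self, index):
        return self.patches[index]

The dataset could then be wrapped in a DataLoader as usual, e.g. DataLoader(dataset, batch_size=8, shuffle=True).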

Thanks a million for your examples. It should work. By the way, how should I write __getitem__ if I get a tensor of torch.Size([10, 2, 4, 4, 64, 64, 64])?

It depends on your model.
What input does your model expect?
If you want to use the 64x64x64 patches, you could reshape them into the batch dimension:

x = x.contiguous().view(-1, 64, 64, 64)  # unfold returns a non-contiguous view, hence the contiguous() call
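
If your model expects an explicit channel dimension (N, C, D, H, W), a small follow-up sketch (assuming single-channel volumes):

x = x.unsqueeze(1)  # [320, 1, 64, 64, 64], ready for Conv3d layers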

Yeah, my input size is 64 x 64 x 64. However, I run out of memory if I use x.view(-1, 64, 64, 64) to get torch.Size([320, 64, 64, 64]); it exceeds the memory of my GPU.

Try to lower your batch size and run it again.
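
If a smaller batch size alone doesn't help, you could also push the patches through the network in smaller chunks. A rough sketch, where model, the chunk size of 32, the single channel dimension, and the 'cuda' device are just assumptions:

patches = x.contiguous().view(-1, 1, 64, 64, 64)  # all 320 patches, assuming a single channel
outputs = []
with torch.no_grad():  # inference only, so no gradients are stored
    for chunk in torch.split(patches, 32, dim=0):  # 32 patches per forward pass
        outputs.append(model(chunk.to('cuda')).cpu())
outputs = torch.cat(outputs, dim=0)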

I have a 2D version here, but my version does not pad with zeros or discard leftover rows and columns (like the TensorFlow padding modes "SAME" and "VALID"). My version just adds the leftover patches, so there is some overlap between them.

def extract_patches_2D(img, size):
    # img: (N, C, H, W); size: (patch_H, patch_W)
    patch_H, patch_W = min(img.size(2), size[0]), min(img.size(3), size[1])
    # non-overlapping patches along H
    patches_fold_H = img.unfold(2, patch_H, patch_H)
    if img.size(2) % patch_H != 0:
        # append one extra (overlapping) patch row covering the leftover pixels
        patches_fold_H = torch.cat((patches_fold_H, img[:, :, -patch_H:, :].permute(0, 1, 3, 2).unsqueeze(2)), dim=2)
    # non-overlapping patches along W
    patches_fold_HW = patches_fold_H.unfold(3, patch_W, patch_W)
    if img.size(3) % patch_W != 0:
        # append one extra (overlapping) patch column covering the leftover pixels
        patches_fold_HW = torch.cat((patches_fold_HW, patches_fold_H[:, :, :, -patch_W:, :].permute(0, 1, 2, 4, 3).unsqueeze(3)), dim=3)
    patches = patches_fold_HW.permute(0, 2, 3, 1, 4, 5).reshape(-1, img.size(1), patch_H, patch_W)
    return patches

import numpy as np
import matplotlib.pyplot as plt
from torchvision.utils import make_grid

patches = extract_patches_2D(img,size)
print(patches.shape)
nrow = int(np.ceil(float(img.size(3))/size[1]))
show_patches = make_grid(patches,nrow=nrow).permute(1,2,0).numpy()
plt.imshow(show_patches)
plt.show()

I have a newer version of the code that extracts patches from a batch of images and also recovers the images from the extracted patches. It can also extract overlapping patches and recover the image from them.
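
That newer version isn't posted here, but for reference, one generic way to recover an image from regularly strided (possibly overlapping) patches is F.fold, averaging the overlapping regions. This is only a sketch (the function name is made up) and it assumes the patches are ordered the way F.unfold enumerates sliding positions, so it is not a drop-in replacement for the leftover-patch handling above:

import torch
import torch.nn.functional as F

def fold_patches_2D(patches, output_size, kernel_size, stride):
    # patches: (L, C, kH, kW), ordered the way F.unfold enumerates the sliding positions
    cols = patches.reshape(patches.size(0), -1).t().unsqueeze(0)  # (1, C*kH*kW, L)
    summed = F.fold(cols, output_size, kernel_size=kernel_size, stride=stride)
    # count how many patches touch each pixel, then average the overlapping regions
    counts = F.fold(torch.ones_like(cols), output_size, kernel_size=kernel_size, stride=stride)
    return summed / counts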


Hi @ptrblck ,

Thanks very much for your suggestion. I am actually in the same situation: extracting 3D patches from MRI volumes.

Here is the original dimension of my MRI: 169 x 208 x 179. I set patch_size = 21 and stride_size = 10. Using your code, patches = x.unfold(1, size, stride).unfold(2, size, stride).unfold(3, size, stride), I finally got patches of shape
[15, 19, 16, 21, 21, 21] (in total, 4560 patches of size 21 x 21 x 21).

I just want to know: is there a formula to calculate the resulting number of patches based on the patch_size and stride_size? And how does torch.unfold deal with the borders of the image? Does it pad with zeros, or discard the remainder if it is not enough for a full patch?

The documentation does not give much information on how unfold handles this.

Thanks very much!!!


You can use the formula of nn.Conv2d to calculate the output shape:

out = ⌊(input_size + 2 × padding − dilation × (kernel_size − 1) − 1) / stride + 1⌋

Padding won't be used, so the output size is floored and any leftover values that don't fit a full patch are discarded.
If you need to pad your input, you could use F.pad on your input before applying unfold.
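
As a quick check for the MRI example above (169 x 208 x 179, patch size 21, stride 10, no padding, dilation 1), the formula can be written out in a few lines:

def num_patches(input_size, kernel_size, stride, padding=0, dilation=1):
    # output-size formula of nn.Conv2d / unfold, floored
    return (input_size + 2 * padding - dilation * (kernel_size - 1) - 1) // stride + 1

print([num_patches(s, 21, 10) for s in (169, 208, 179)])
# [15, 19, 16] -> 15 * 19 * 16 = 4560 patches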


@ptrblck Great, thanks. So I guess the dilation is not used and I can just set it to 1? That matches the output in my case.

Yes, I guess you could add the dilation after using unfold somehow, if that is needed.
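
One possible way this could look (just a sketch, assuming x has the same layout as in your snippet above and that the dilated window still fits into each dimension): unfold with the dilated window size and then keep every dilation-th voxel inside each window:

size = 21
stride = 10
dilation = 2                         # assumed value for illustration
window = dilation * (size - 1) + 1   # effective extent of a dilated window (41 here)
patches = x.unfold(1, window, stride).unfold(2, window, stride).unfold(3, window, stride)
patches = patches[..., ::dilation, ::dilation, ::dilation]  # back to 21 x 21 x 21 patches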


I need some help - I have successfully managed to convert my image into a set of patches. But I also need to use the patches to reconstruct the image. I’m very new to PyTorch but it seems like there should be a reverse method of unfold to reconstruct the image from the patches.

nn.Fold or a view on your patches might work.
Could you post the code you are using in case you get stuck?

Thanks for getting back to me!

I have an image which is of size 5176x3793 and I am breaking it up into patches of 128x128 pixels with the channels in the first dimension. Here’s the code:

img.data.shape # torch.Size([3, 5176, 3793])
patches = img.data.unfold(0, 3, 3).unfold(1, 128, 128).unfold(2, 128, 128)
patches.shape # torch.Size([1, 40, 29, 3, 128, 128])

I am processing the patches individually and after doing so I would like to reconstruct the image from the patches. I have put together a short stub which works but I am wondering if there is a more elegant way to achieve the same results using some in built methods.

patches = patches[0]
pw = patches.shape[-2] # patch width
ph = patches.shape[-1] # patch height
w = patches.shape[0] * pw 
h = patches.shape[1] * ph
new_img = torch.zeros([3, w, h])
for i in range(patches.shape[0]):
    for j in range(patches.shape[1]):
        patch_pixels = patches[i][j] # shape [3, pw, ph]
        new_img[:, pw*i: pw*(i+1), ph*j: ph*(j+1)] = patch_pixels
new_img.shape # torch.Size([3, 5120, 3712])

I also have another question, using torch.Tensor.unfold can result in pixels being lost in the slicing process because of the dimensions not matching up. Is there a way to avoid pixels being lost and keeping the original dimensions of the image?

Since the current shape is not divisible by the kernel size and the stride without losing pixels, you would have to pad the input manually before applying unfold.
I’ve just used some values to increase the size to the next multiple of 128.
Then a permute with view will reconstruct the original input:

import torch
import torch.nn.functional as F

img = torch.randn(3, 5176, 3793)
# pad W by (23, 24) and H by (36, 36) to reach multiples of 128: (3, 5248, 3840)
img = F.pad(img, (23, 24, 36, 36))
patches = img.data.unfold(0, 3, 3).unfold(1, 128, 128).unfold(2, 128, 128)

re = patches.permute(0, 3, 1, 4, 2, 5).contiguous().view_as(img)
print((re == img).all())
> tensor(1, dtype=torch.uint8)

Thanks for the help!!! I really appreciate it.

This works perfectly. When reconstructing the image from the patches, I just want to recover the original image without the padding, so I'm slicing re to get what I need.
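
For reference, the slicing looks roughly like this (matching the padding values from the example above):

re_cropped = re[:, 36:36 + 5176, 23:23 + 3793]  # undo F.pad(img, (23, 24, 36, 36))
print(re_cropped.shape)  # torch.Size([3, 5176, 3793])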

Can you recommend a good tutorial or blog post to get more familiar with these tensor manipulation methods in PyTorch? The docs aren’t the most helpful at times.

I'm not sure if there is a good tutorial on these kinds of slicing operations, but it might be a good idea to create one! :)
