Patch Making: Does PyTorch Have Anything to Offer?

I am working on my first CNN challenge, and so far what has amazed me is that PyTorch offers an easy solution for almost everything I need, except for patch making.

I have 3-dimensional data samples with dimensions around 500x500x500. Such a huge piece of data of course can't be fed to the network in one go, so patching is required. In my case the dimensions are not consistent and I do not allow padding, hence some fairly complex coding is required on my part to create and reassemble these patches.

Say you want to create 64x64x64 patches out of a 500x500x500 data sample, what do you guys do?

PyTorch solves batch making for you with the DataLoader, but what about patch making? Does PyTorch really not have anything to offer in that regard?

Just use indexing

image[64:128, :, :]

And, I don’t think patching is a very general need.

I am sorry that I pissed you off. That was not my intent.

Well, the suggestion you made is something I am already doing. In my case that is a solution with hundreds of lines of code, involving at least 3 nested for loops.

And when you do not allow padding and want the patches to overlap instead, it becomes even more complex.
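
Just to illustrate, a stripped-down version of what I mean looks roughly like this (a simplified sketch without the overlap handling and border cases that make the real code so long):

import torch

x = torch.randn(500, 500, 500)   # one data sample (the real dimensions vary)
k = 64                           # patch size
patches = []
for i in range(0, x.size(0) - k + 1, k):
    for j in range(0, x.size(1) - k + 1, k):
        for l in range(0, x.size(2) - k + 1, k):
            patches.append(x[i:i + k, j:j + k, l:l + k])
print(len(patches))  # 343 patches for a 500x500x500 sample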

I don't know whether it is general or not, but I can tell by now that it must be needed in almost any project that involves medical imaging.

Sorry for making you think I'm pissed off. English is not my first language; I'm just making suggestions :crazy_face:

Sorry, your English is fine. It was me who misread. English is not my first language either. It will not let me remove that comment for another 24 hours :confused:

Sorry for the misunderstanding

I think patching is something that belongs to data pre-processing or data augmentation, like a random 5- or 10-crop?
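
For example, a random 3D crop could look something like this (just a sketch with a hypothetical helper, not a built-in transform):

import torch

def random_crop_3d(volume, size=(64, 64, 64)):
    # Randomly crop a fixed-size patch from the last three dimensions
    d, h, w = volume.shape[-3:]
    kd, kh, kw = size
    z0 = torch.randint(0, d - kd + 1, (1,)).item()
    y0 = torch.randint(0, h - kh + 1, (1,)).item()
    x0 = torch.randint(0, w - kw + 1, (1,)).item()
    return volume[..., z0:z0 + kd, y0:y0 + kh, x0:x0 + kw]

patch = random_crop_3d(torch.randn(1, 500, 500, 500))
print(patch.shape)  # torch.Size([1, 64, 64, 64])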

I think .unfold would work in this case.
Have a look at this code:

x = torch.randn(1, 500, 500, 500)  # batch, c, h, w
kc, kh, kw = 64, 64, 64  # kernel size
dc, dh, dw = 64, 64, 64  # stride
patches = x.unfold(1, kc, dc).unfold(2, kh, dh).unfold(3, kw, dw)
patches = patches.contiguous().view(patches.size(0), -1, kc, kh, kw)
print(patches.shape)
> torch.Size([1, 343, 64, 64, 64])

As you can see, patches will give you 7*7*7=343 patches, each of shape [64, 64, 64].
If you would like to overlap the patches, you should change the stride for each dimension.
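
For example, with the same 500x500x500 volume a stride of 32 would make consecutive patches overlap by half (a rough sketch; unfold only creates views, so this stays cheap):

x = torch.randn(1, 500, 500, 500)
kc, kh, kw = 64, 64, 64  # kernel size
dc, dh, dw = 32, 32, 32  # stride < kernel size -> overlapping patches
patches = x.unfold(1, kc, dc).unfold(2, kh, dh).unfold(3, kw, dw)
print(patches.shape)
> torch.Size([1, 14, 14, 14, 64, 64, 64])  # 14**3 = 2744 overlapping patches
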
Let me know if that works for you.

Can you reassemble them using a similar approach?

500 % 64 is not zero, so of course you cannot recreate the original 500x500x500 again, but can you reassemble the 343 patches back into something that makes sense?

Imagine the 500x500x500 being a CT screening.

Yes, this would be possible with some permutations and reshaping.
Here is the corresponding code:

x = torch.randn(1, 500, 500, 500)  # batch, c, h, w
kc, kh, kw = 64, 64, 64  # kernel size
dc, dh, dw = 64, 64, 64  # stride
patches = x.unfold(1, kc, dc).unfold(2, kh, dh).unfold(3, kw, dw)
unfold_shape = patches.size()
patches = patches.contiguous().view(patches.size(0), -1, kc, kh, kw)
print(patches.shape)

# Reshape back into the patch grid
patches_orig = patches.view(unfold_shape)
output_c = unfold_shape[1] * unfold_shape[4]  # 7 * 64 = 448
output_h = unfold_shape[2] * unfold_shape[5]
output_w = unfold_shape[3] * unfold_shape[6]
# Interleave the patch-grid dimensions with the in-patch dimensions
patches_orig = patches_orig.permute(0, 1, 4, 2, 5, 3, 6).contiguous()
patches_orig = patches_orig.view(1, output_c, output_h, output_w)

# Check for equality (only the 448x448x448 region covered by the patches)
print((patches_orig == x[:, :output_c, :output_h, :output_w]).all())

Thank you. You have made my day!

I have been playing around with your fine example.

I'm not sure it works correctly when it comes to changing the stride.

Consider the dimensions 407x301x360 with a stride of 20 and a kernel of 64.
According to my calculation that should yield 20x15x18 = 5400 patches. Alas, your code only yields 3240 patches.

I might be calculating it wrong.

In my project I create 64-dim patches, and when processed through the network they yield 24-dim result patches. These patches need to be created with enough overlap to capture the entire screening, and they need to be put back together in the same order they were extracted from the original screening.


This is a little illustration I made:
The black box is the original screening.
The small red boxes are the resulting 24-dim patches.
The gray box is a 64-dim patch.
The different colored boxes illustrate how the stride has to work.
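
Roughly, the start-index bookkeeping I end up writing by hand looks like this (a simplified one-axis sketch; the stride of 24 assumes each 64-dim patch contributes its central 24-dim result, and the last window is shifted back so it still ends inside the volume):

def window_starts(length, kernel=64, stride=24):
    # Start indices of overlapping windows covering one axis (assumes length >= kernel)
    starts = list(range(0, length - kernel + 1, stride))
    if starts[-1] + kernel < length:      # the last window would miss the border,
        starts.append(length - kernel)    # so add an extra one ending at the edge
    return starts

print(window_starts(407))  # [0, 24, 48, ..., 336, 343]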

How did you calculate the number of patches?
According to the nn.Conv2d docs you should get 18*12*15=3240 patches:

c = 407
h = 301
w = 360
dilation = 1
padding = 0
kernel_size = 64
stride = 20

# floor division matches the floor in the output size formula
print((c + 2*padding - dilation * (kernel_size - 1) - 1) // stride + 1)  # 18
print((h + 2*padding - dilation * (kernel_size - 1) - 1) // stride + 1)  # 12
print((w + 2*padding - dilation * (kernel_size - 1) - 1) // stride + 1)  # 15

407/20 = 20
301/20 = 15
360/20 = 18

20x15x18 = 5400

I am using conv3d. A CT screening is not a 2-dimensional image; it is a 3-dimensional voxel space (NIfTI/DICOM).

But as far as I can see, x in your example is 3-dimensional (technically 4). So the problem is perhaps that you are only striding along 2 axes and not 3?

Only guessing, and perhaps I am the one miscalculating.

Yeah, I assumed you are using a 3-dim medical image. That’s why I also used the “channel” dimension to create the patches. Otherwise the complete channels would be used in each 2-dim patch.

The size calculation would be the same for nn.Conv3d, but apparently I posted the definition for nn.Conv2d :wink:
Your calculations, for instance, ignore the kernel size and assume some padding. Note that the "last" kernel might not fit into your input for a certain stride.
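
As a quick sanity check (a small sketch using the 407x301x360 volume from above), unfold reports the same counts as the formula:

x = torch.randn(1, 407, 301, 360)
patches = x.unfold(1, 64, 20).unfold(2, 64, 20).unfold(3, 64, 20)
print(patches.shape[1:4])
> torch.Size([18, 12, 15])  # 18 * 12 * 15 = 3240 patches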

Hi ptrblck, I've been trying unfold. How would I get the batch size right? Should I just feed the reshaped tensor with (num_patches_in_y * num_patches_in_x * batch_size, height, width, channel) shape?

Assuming you are using this code snippet, you wouldn't have to change anything besides assigning a different batch size in the input and the view operation:

# changes needed
batch_size = 16
x = torch.randn(batch_size, 500, 500, 500)  # batch, c, h, w
...
patches_orig = patches_orig.view(batch_size, output_c, output_h, output_w)

since the code does not apply any operation to the batch dimension and thus keeps it unchanged.
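
For completeness, here is a small self-contained sketch with a batch size of 2 and a 128x128x128 volume (chosen so the kernel divides the size evenly), which round-trips the patches back to the input:

import torch

batch_size = 2
x = torch.randn(batch_size, 128, 128, 128)  # batch, c, h, w
kc, kh, kw = 64, 64, 64  # kernel size
dc, dh, dw = 64, 64, 64  # stride

patches = x.unfold(1, kc, dc).unfold(2, kh, dh).unfold(3, kw, dw)
unfold_shape = patches.size()  # [2, 2, 2, 2, 64, 64, 64]
patches = patches.contiguous().view(batch_size, -1, kc, kh, kw)
print(patches.shape)
> torch.Size([2, 8, 64, 64, 64])

# Reshape back
patches_orig = patches.view(unfold_shape)
output_c = unfold_shape[1] * unfold_shape[4]
output_h = unfold_shape[2] * unfold_shape[5]
output_w = unfold_shape[3] * unfold_shape[6]
patches_orig = patches_orig.permute(0, 1, 4, 2, 5, 3, 6).contiguous()
patches_orig = patches_orig.view(batch_size, output_c, output_h, output_w)

print((patches_orig == x).all())
> tensor(True)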
