How to crop a window around each element of a tensor at the same time, without using a for loop

I have a tensor of size 3 x 240 x 320, and I want to slide a window (window size ws=5) over each pixel and crop out the sub-tensor centered at that pixel. The final output dimension should be 3 x ws x ws x 240 x 320. So I pad the original tensor by half the window size on each side and use a for loop to do the cropping.

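import torch
import torch.nn.functional as F

ws = 5  # window size
height, width = 240, 320

image = torch.randn(1, 3, height, width)
# pad by ws // 2 on every side so each original pixel has a full window around it
image = F.pad(image, (ws // 2, ws // 2, ws // 2, ws // 2), mode='reflect')
patches = torch.zeros(1, 3, ws, ws, height, width)
for i in range(height):
    for j in range(width):
        # window whose center is original pixel (i, j)
        patches[:, :, :, :, i, j] = image[:, :, i:i+ws, j:j+ws]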

Is there any way to do the cropping for every pixel at the same time, i.e. without the for loop over each pixel? I feel like it's pretty similar to a convolution operation, but I can't think of a way to crop efficiently. Thanks in advance!

Take a moment to look through the docs; torch.nn.Unfold does what you are after.
https://pytorch.org/docs/stable/generated/torch.nn.Unfold.html#torch.nn.Unfold
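For anyone skimming, here is a minimal sketch of what unfold returns, using the functional form F.unfold. The toy tensor x and its sizes are made up for illustration and are not from the question. Each column of the result is one flattened sliding window, ordered channel first, then window row, then window column.

```python
import torch
import torch.nn.functional as F

# Toy input: batch 1, 2 channels, 4 x 5 spatial grid (sizes chosen only for illustration).
x = torch.arange(40, dtype=torch.float32).view(1, 2, 4, 5)

# kernel_size=3 with the default stride of 1 and no padding
# gives (4 - 3 + 1) * (5 - 3 + 1) = 6 window positions.
cols = F.unfold(x, kernel_size=3)
print(cols.shape)  # torch.Size([1, 18, 6]) -> (batch, channels * 3 * 3, positions)

# Column 0 is the flattened top-left 3 x 3 window of every channel.
first_window = x[:, :, 0:3, 0:3].reshape(1, -1)
print(torch.equal(cols[:, :, 0], first_window))  # True
```

Because the window positions are flattened into the last dimension, a view or reshape is usually needed afterwards to get the spatial layout back.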

Great, thanks! I was able to get what I wanted using this:

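import torch
import torch.nn.functional as F

ws = 5  # window size
image = torch.randn(1, 3, 240, 320)
image = F.pad(image, (ws // 2, ws // 2, ws // 2, ws // 2), mode='reflect')
# F.unfold returns (1, 3 * ws**2, 240 * 320); use .view(1, 3, ws, ws, 240, 320) to keep the two window axes separate
patches = F.unfold(image, kernel_size=ws).view(1, 3, ws**2, 240, 320)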
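In case it helps someone else, a quick sanity check that the unfold version matches the original double loop. This is only a sketch: the small height and width are an assumption to keep the loop fast, and the names padded, loop_patches and unfold_patches are made up for the comparison.

```python
import torch
import torch.nn.functional as F

ws = 5                 # window size from the question
height, width = 6, 8   # small hypothetical size so the reference loop stays fast

image = torch.randn(1, 3, height, width)
padded = F.pad(image, (ws // 2, ws // 2, ws // 2, ws // 2), mode='reflect')

# Reference: the explicit per-pixel loop from the question.
loop_patches = torch.zeros(1, 3, ws, ws, height, width)
for i in range(height):
    for j in range(width):
        loop_patches[:, :, :, :, i, j] = padded[:, :, i:i + ws, j:j + ws]

# Vectorized: F.unfold returns (1, 3 * ws * ws, height * width),
# which is then split back into channel, window and spatial axes.
unfold_patches = F.unfold(padded, kernel_size=ws).view(1, 3, ws, ws, height, width)

print(loop_patches.shape == unfold_patches.shape)
print(torch.equal(loop_patches, unfold_patches))
```

Both prints should report a match, since unfold only copies values and orders each column as channel, window row, window column.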