Slicing image into square patches

Is there a slicing function that slices an image into square patches of size p? Right now I am doing this in a few lines of code, but if there is a function that does this I would just use it instead.

Look into the crop transforms in torchvision.transforms (CenterCrop, RandomCrop, FiveCrop, and so on): http://pytorch.org/docs/master/torchvision/transforms.html
AFAIK, there is no ready-made transform for a crop at an arbitrary location, but it is easy to implement using torchvision.transforms.Lambda.

Quite late, but for others' reference: you can use the unfold function.

patches = img_t.unfold(0, 3, 3).unfold(1, 8, 8).unfold(2, 8, 8)

Here is the demo code to test:

import numpy as np
import torch
import matplotlib.pyplot as plt
from PIL import Image
from torchvision import transforms

# In a Jupyter notebook, uncomment the next line to render figures inline.
# %matplotlib inline

transt = transforms.ToTensor()
transp = transforms.ToPILImage()
img_t = transt(Image.open('cifar/train/10000_automobile.png'))

# torch.Tensor.unfold(dimension, size, step)
# Slices the image into 8x8 patches. For a 3x32x32 CIFAR image,
# the result has shape (1, 4, 4, 3, 8, 8).
patches = img_t.unfold(0, 3, 3).unfold(1, 8, 8).unfold(2, 8, 8)

print(patches[0][0][0].shape)

def visualize(patches):
    """Show the 4x4 grid of patches."""
    fig = plt.figure(figsize=(4, 4))
    for i in range(4):
        for j in range(4):
            # Convert the (3, 8, 8) patch tensor to a PIL image, then to an array.
            inp = transp(patches[0][i][j])
            inp = np.array(inp)

            ax = fig.add_subplot(4, 4, ((i*4)+j)+1, xticks=[], yticks=[])
            ax.imshow(inp)

visualize(patches)
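
For an arbitrary patch size p, the same unfold idea can be wrapped in a small helper. This is only a sketch: to_patches is a name made up here, the input is assumed to be a (C, H, W) tensor, and a synthetic tensor stands in for a real image. It unfolds only the height and width dimensions and then moves the channel dimension next to each patch:

```python
import torch

def to_patches(img, p):
    """Hypothetical helper: split a (C, H, W) tensor into non-overlapping p x p patches.

    Returns a tensor of shape (H // p, W // p, C, p, p). Rows or columns
    that do not fill a whole patch are dropped, which is how unfold behaves.
    """
    patches = img.unfold(1, p, p).unfold(2, p, p)  # (C, H//p, W//p, p, p)
    return patches.permute(1, 2, 0, 3, 4)          # (H//p, W//p, C, p, p)

# Synthetic 3x32x32 "image" so the example runs without any files.
img = torch.arange(3 * 32 * 32, dtype=torch.float32).reshape(3, 32, 32)
patches = to_patches(img, 8)
print(patches.shape)

# Patch (i, j) matches slicing the image directly.
same = torch.equal(patches[1, 2], img[:, 8:16, 16:24])
print(same)
```

With this layout, patches[i, j] is the patch in row i and column j, and patches.reshape(-1, 3, p, p) gives a flat batch of patches if that is more convenient.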