Is there a slicing function to slice an image into square patches of size p? Right now I am doing this in a few lines of code, but if there is a function that does this I would just replace them.
Look into the torchvision.transforms.*Crop transforms (for example CenterCrop, RandomCrop, FiveCrop): http://pytorch.org/docs/master/torchvision/transforms.html
Afaik, there is no built-in transform for a crop at an arbitrary location, but it is easy to implement with torchvision.transforms.Lambda.
Quite late, but for reference for others: you can use the torch.Tensor.unfold function:
patches = img_t.data.unfold(0, 3, 3).unfold(1, 8, 8).unfold(2, 8, 8)
Here is the demo code to test:
import numpy as np
import torch
import matplotlib.pyplot as plt
from PIL import Image
from torchvision import transforms
# %matplotlib inline  (uncomment when running in a Jupyter notebook)

transt = transforms.ToTensor()    # PIL image -> (C, H, W) float tensor
transp = transforms.ToPILImage()  # tensor -> PIL image
img_t = transt(Image.open('cifar/train/10000_automobile.png'))

# torch.Tensor.unfold(dimension, size, step)
# Slice the image into non-overlapping 8x8 patches.
patches = img_t.data.unfold(0, 3, 3).unfold(1, 8, 8).unfold(2, 8, 8)
print(patches[0][0][0].shape)  # one patch: (3, 8, 8) for an RGB image

def visualize(patches):
    """Show the top-left 4x4 grid of patches."""
    fig = plt.figure(figsize=(4, 4))
    for i in range(4):
        for j in range(4):
            # Patch tensor -> PIL image -> array that imshow can draw.
            inp = np.array(transp(patches[0][i][j]))
            ax = fig.add_subplot(4, 4, (i * 4) + j + 1, xticks=[], yticks=[])
            ax.imshow(inp)
    plt.show()

visualize(patches)
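For the general case in the original question (square patches of size p), the same unfold idea can be wrapped in a small helper. This is only a sketch under the assumption that the input is a (C, H, W) tensor; the name extract_patches is made up for illustration, and it uses a synthetic tensor so it runs without an image file:

```python
import torch

def extract_patches(img, p):
    """Split a (C, H, W) tensor into non-overlapping p x p patches.

    Returns a tensor of shape (H // p, W // p, C, p, p). Rows and columns
    that do not fill a whole patch are dropped by unfold.
    """
    # unfold(dim, size, step) appends each window as a new last dimension.
    patches = img.unfold(1, p, p).unfold(2, p, p)  # (C, H//p, W//p, p, p)
    return patches.permute(1, 2, 0, 3, 4)          # (H//p, W//p, C, p, p)

img = torch.arange(3 * 32 * 32, dtype=torch.float32).reshape(3, 32, 32)
patches = extract_patches(img, p=8)
print(patches.shape)  # torch.Size([4, 4, 3, 8, 8])

# Patch (i, j) matches the corresponding slice of the image.
print(torch.equal(patches[1, 2], img[:, 8:16, 16:24]))  # True
```

Here patches[i, j] is the patch in row i and column j of the grid. Unfolding only the two spatial dimensions avoids the extra leading dimension that unfold(0, 3, 3) introduces in the snippet above.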