tf.extract_image_patches in PyTorch

I’m not sure why the method is called extract_image_patches, since it doesn’t return the patches themselves but rather a view of shape [batch_size, height, width, channels*kernel_height*kernel_width].
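
For reference, here is a minimal sketch of the TF call being matched; it assumes an NHWC input with the same sizes as below, a 3x3 kernel, stride 1, and 'SAME' padding (in TF 2.x the op is exposed as tf.image.extract_patches):

import tensorflow as tf

x_tf = tf.random.normal([128, 32, 32, 16])  # NHWC, as TF expects
patches_tf = tf.image.extract_patches(
    images=x_tf,
    sizes=[1, 3, 3, 1],     # [1, kernel_height, kernel_width, 1]
    strides=[1, 1, 1, 1],
    rates=[1, 1, 1, 1],     # dilation factors
    padding='SAME',
)
print(patches_tf.shape)  # (128, 32, 32, 144), i.e. 3*3*16 values per spatial location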

In any case, this code should yield the same result in PyTorch:

import torch
import torch.nn.functional as F

batch_size = 128
channels = 16
height, width = 32, 32
x = torch.randn(batch_size, channels, height, width)

kh, kw = 3, 3
dh, dw = 1, 1

# Pad H and W by 1 on each side so the output keeps the input spatial size
# (matches TF 'SAME' padding for a 3x3 kernel with stride 1)
x = F.pad(x, (1, 1, 1, 1))

# get all image windows of size (kh, kw) and stride (dh, dw)
patches = x.unfold(2, kh, dh).unfold(3, kw, dw)
print(patches.shape)  # [128, 16, 32, 32, 3, 3]
# Permute so that channels are next to patch dimension
patches = patches.permute(0, 2, 3, 1, 4, 5).contiguous()  # [128, 32, 32, 16, 3, 3]
# View as [batch_size, height, width, channels*kh*kw]
patches = patches.view(*patches.size()[:3], -1)
print(patches.shape)
> torch.Size([128, 32, 32, 144])

Note that in PyTorch the channel dimension is dim1, so I changed your input shape to match the PyTorch convention. :wink:
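
As an alternative sketch, F.unfold can produce the same flattened patches directly; this assumes the same input shape, 3x3 kernel, stride 1, and zero padding of 1 as above:

import torch
import torch.nn.functional as F

x = torch.randn(128, 16, 32, 32)
# F.unfold applies the padding itself and returns [batch, channels*kh*kw, num_locations]
cols = F.unfold(x, kernel_size=(3, 3), padding=1, stride=1)
print(cols.shape)  # [128, 144, 1024]
# Move the location dim forward and unflatten it to [batch, height, width, channels*kh*kw]
patches_alt = cols.transpose(1, 2).contiguous().view(128, 32, 32, -1)
print(patches_alt.shape)  # [128, 32, 32, 144]

If I’m not mistaken, the per-location flattening order (channels first, then the kernel rows and columns) should match the unfold-based result above.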
