tf.extract_image_patches in PyTorch

Is there a function like tf.extract_image_patches in PyTorch?
Thank you

I think tensor.unfold should correspond to this method.
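
As a minimal sketch of what tensor.unfold does (the toy values below are made up for illustration): it slides a window of a given size along one dimension with a given step, and the windows overlap whenever the step is smaller than the size.

```python
import torch

# Toy 1D input with values 0..5 (made up for illustration)
x = torch.arange(6)

# Window size 3, step 1: neighbouring windows share two elements
windows = x.unfold(0, 3, 1)
print(windows.shape)  # torch.Size([4, 3])
print(windows)
# tensor([[0, 1, 2],
#         [1, 2, 3],
#         [2, 3, 4],
#         [3, 4, 5]])
```

Applying unfold to the height and then to the width dimension gives 2D windows in the same way.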

Does it support overlapping windows as well?
I have a tensor of size (128, 32, 32, 16)
with ksize = [1,3,3,1], stride = [1,1,1,1] and rate = [1,1,1,1].
In TensorFlow this gives me an output of size (128,32,32,144).
However, I can't figure out a way to get the same result in PyTorch using unfold.
Thank you

Could you explain how the output is calculated in TensorFlow?
It looks like you somehow reshaped the patches to the original input size (128x32x32)?

This is the output tensor of a convolution layer: 128 is the batch size, 32x32 is the spatial size, and 16 is the number of output channels of that layer. The model is ResNet-56 on CIFAR-10.

tensor.unfold will give you the image patches, so I assume you are applying some operation to each patch and reshaping the result to this activation shape afterwards?
What operation would you like to apply? Maybe a standard nn.Conv2d module will also do the trick?

In TensorFlow, this input tensor (128,32,32,16) is passed to tf.extract_image_patches with kernel size (1,3,3,1) and stride (1,1,1,1), and the output is a tensor of size (128,32,32,144). So I assume tf.extract_image_patches samples overlapping patches. I tried tensor.unfold and could not reproduce this. Maybe it can, and I just don't know how to make it produce the same output as TensorFlow…
Using nn.Conv2d with all weights set to 1 is an idea, but it would sum over each patch instead of returning the sampled values…
Thank you again…

I’m not sure why the method is called extract_image_patches if you don’t get the patches themselves but, apparently, a view of shape [batch_size, height, width, channels*kernel_height*kernel_width].

However, this code should yield the same result in PyTorch:

import torch
import torch.nn.functional as F

batch_size = 128
channels = 16
height, width = 32, 32
x = torch.randn(batch_size, channels, height, width)

kh, kw = 3, 3
dh, dw = 1, 1

# Pad tensor to get the same output
x = F.pad(x, (1, 1, 1, 1))

# get all image windows of size (kh, kw) and stride (dh, dw)
patches = x.unfold(2, kh, dh).unfold(3, kw, dw)
print(patches.shape)  # [128, 16, 32, 32, 3, 3]
# Permute so that channels are next to patch dimension
patches = patches.permute(0, 2, 3, 1, 4, 5).contiguous()  # [128, 32, 32, 16, 3, 3]
# View as [batch_size, height, width, channels*kh*kw]
patches = patches.view(*patches.size()[:3], -1)
print(patches.shape)  # torch.Size([128, 32, 32, 144])

Note that in PyTorch the channel dimension is dim1, so I changed your input shape to match the PyTorch convention.
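
As a cross-check, here is a sketch comparing the tensor.unfold approach with torch.nn.functional.unfold, which, as far as I know, does the zero padding and the window extraction in a single call. The batch size of 4 is an arbitrary choice to keep the check fast.

```python
import torch
import torch.nn.functional as F

# Smaller batch than in the thread, just to keep the check fast
x = torch.randn(4, 16, 32, 32)
kh, kw = 3, 3

# tensor.unfold approach: pad, unfold twice, permute, flatten
padded = F.pad(x, (1, 1, 1, 1))
a = padded.unfold(2, kh, 1).unfold(3, kw, 1)           # [4, 16, 32, 32, 3, 3]
a = a.permute(0, 2, 3, 1, 4, 5).reshape(4, 32, 32, -1)  # [4, 32, 32, 144]

# F.unfold pads internally and returns [B, C*kh*kw, L] with L = 32*32
b = F.unfold(x, kernel_size=(kh, kw), padding=1)        # [4, 144, 1024]
b = b.view(4, -1, 32, 32).permute(0, 2, 3, 1)           # [4, 32, 32, 144]

print(a.shape, b.shape)
print(torch.equal(a, b))  # both only copy values, so an exact match is expected
```

Both variants order the last dimension as (channel, kh, kw). The TensorFlow docs describe the output depth as ksize_rows * ksize_cols * depth, so if the exact element order matters, the permute would probably need to put the kernel dims before the channel dim.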

Thank you so much for the reply! Really appreciate it! I will give it a try! Thanks again!!

Thanks @ptrblck for your solution.
I put it into a function that takes the usual PyTorch format BxCxHxW and should mimic the TF function. Since I am a beginner, there is probably a much more efficient way. Also, I haven’t tried backprop yet…

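import math
import torch.nn.functional as F

def extract_image_patches(x, kernel, stride=1, dilation=1):
    # Do TF 'SAME' padding
    b, c, h, w = x.shape
    h2 = math.ceil(h / stride)
    w2 = math.ceil(w / stride)
    span = (kernel - 1) * dilation + 1  # extent covered by one dilated window
    pad_row = max((h2 - 1) * stride + span - h, 0)
    pad_col = max((w2 - 1) * stride + span - w, 0)
    # F.pad expects (left, right, top, bottom), so the width padding comes first
    x = F.pad(x, (pad_col // 2, pad_col - pad_col // 2, pad_row // 2, pad_row - pad_row // 2))

    # Extract windows covering the dilated extent, then keep every dilation-th element
    patches = x.unfold(2, span, stride).unfold(3, span, stride)
    patches = patches[..., ::dilation, ::dilation]  # [b, c, h2, w2, kernel, kernel]
    # Move the kernel dims in front of the channels: [b, kernel, kernel, c, h2, w2]
    patches = patches.permute(0, 4, 5, 1, 2, 3).contiguous()

    # Flatten to [b, kernel*kernel*c, h2, w2]
    return patches.view(b, -1, patches.shape[-2], patches.shape[-1])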
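
Regarding backprop and efficiency: below is a sketch of a variant built on torch.nn.functional.unfold, together with a quick gradient check. The name extract_patches_same is hypothetical and not part of any library, and the 'SAME' padding rule (extra pixel on the bottom/right) is my understanding of the TF behaviour rather than something verified against TF itself.

```python
import math
import torch
import torch.nn.functional as F

def extract_patches_same(x, kernel, stride=1, dilation=1):
    """Hypothetical helper: TF-style 'SAME' patches from a BxCxHxW tensor."""
    b, c, h, w = x.shape
    out_h, out_w = math.ceil(h / stride), math.ceil(w / stride)
    span = (kernel - 1) * dilation + 1  # extent covered by one dilated window
    pad_h = max((out_h - 1) * stride + span - h, 0)
    pad_w = max((out_w - 1) * stride + span - w, 0)
    # F.pad order is (left, right, top, bottom)
    x = F.pad(x, (pad_w // 2, pad_w - pad_w // 2, pad_h // 2, pad_h - pad_h // 2))
    cols = F.unfold(x, kernel, dilation=dilation, stride=stride)  # [B, C*k*k, L]
    cols = cols.reshape(b, c, kernel, kernel, out_h, out_w)
    # Reorder the depth to (kh, kw, c), as in the function above
    return cols.permute(0, 2, 3, 1, 4, 5).reshape(b, -1, out_h, out_w)

# Non-square input on purpose, to exercise the row/column padding separately
x = torch.randn(2, 16, 32, 20, requires_grad=True)
out = extract_patches_same(x, kernel=3, stride=2, dilation=2)
print(out.shape)  # torch.Size([2, 144, 16, 10])

# Patch extraction only copies values, so gradients flow back to the input
out.sum().backward()
print(x.grad.shape)  # torch.Size([2, 16, 32, 20])
```

Testing with a non-square input is useful here because a mix-up between row and column padding goes unnoticed when height and width are equal.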