I’m not sure why the method is called extract_image_patches, since you don’t get the individual patches back, but rather a view of shape [batch_size, height, width, channels*kernel_height*kernel_width].
However, this code should yield the same result in PyTorch:
import torch
import torch.nn.functional as F
batch_size = 128
channels = 16
height, width = 32, 32
x = torch.randn(batch_size, channels, height, width)
kh, kw = 3, 3
dh, dw = 1, 1
# Pad the input so the output keeps the same spatial size (SAME padding)
x = F.pad(x, (1, 1, 1, 1))
# Get all image windows of size (kh, kw) with stride (dh, dw)
patches = x.unfold(2, kh, dh).unfold(3, kw, dw)
print(patches.shape) # [128, 16, 32, 32, 3, 3]
# Permute so that channels are next to patch dimension
patches = patches.permute(0, 2, 3, 1, 4, 5).contiguous() # [128, 32, 32, 16, 3, 3]
# View as [batch_size, height, width, channels*kh*kw]
patches = patches.view(*patches.size()[:3], -1)
print(patches.shape)
> torch.Size([128, 32, 32, 144])
Note that PyTorch uses a channels-first layout (the channel dimension is dim1), so I changed your input shape to match the PyTorch convention.
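As a cross-check, you should get the same result with F.unfold (the im2col op), which flattens each sliding window into a column with the same channel-major ordering. A minimal sketch with smaller shapes (chosen here just for the example):

import torch
import torch.nn.functional as F

batch_size, channels, height, width = 2, 4, 8, 8
kh, kw = 3, 3
x = torch.randn(batch_size, channels, height, width)

# Tensor.unfold-based version from above
padded = F.pad(x, (1, 1, 1, 1))
patches = padded.unfold(2, kh, 1).unfold(3, kw, 1)
patches = patches.permute(0, 2, 3, 1, 4, 5).contiguous()
patches = patches.view(batch_size, height, width, -1)

# F.unfold returns [batch_size, channels*kh*kw, height*width]
unfolded = F.unfold(x, kernel_size=(kh, kw), padding=1)
unfolded = unfolded.transpose(1, 2).view(batch_size, height, width, -1)

print(torch.allclose(patches, unfolded))  # True

F.unfold is usually the more idiomatic choice when you need patches for a matmul, while the double unfold keeps the explicit (kh, kw) window dimensions around.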