`tf.image.extract_patches` in PyTorch

Hi there!

I am currently trying to reproduce `tf.image.extract_patches` for my use case, which is summarised in this gist: from `tf` to `torch` extract to patches · GitHub

The implementations do not match (the assertion below does not pass) and I am not sure what I am doing wrong. I made sure that the interpolation gives the same output, so the divergence starts in the respective “to patch” functions.

I can also continue the thread in the original topic: Tf.extract_image_patches in pytorch - #8 by ptrblck, as the problem is more or less similar.

Below is a simplified test that I am trying to make pass:

import tensorflow as tf
import torch

import torch.nn.functional as F
    

# adapted from: https://discuss.pytorch.org/t/tf-extract-image-patches-in-pytorch/43837/8
def torch_extract_patches(
    x, patch_height, patch_width, padding=None
):
    x = x.unsqueeze(0)
    if padding == "SAME":
        x = F.pad(x, (1, 1, 1, 1))

    patches = x.unfold(2, patch_height, patch_height).unfold(3, patch_width, patch_width)
    # Permute so that channels sit next to the per-patch dimensions
    # -> [N, n_patches_h, n_patches_w, C, patch_w, patch_h]
    patches = patches.permute(0, 2, 3, 1, 5, 4).contiguous()
    # Flatten each patch -> [N, n_patches_h, n_patches_w, C * patch_w * patch_h]
    patches = patches.reshape(*patches.size()[:3], -1)
    return patches

# H x W x C
image_tf = tf.random.uniform(shape=(720, 720, 3))
# C x H x W
image_torch = torch.from_numpy(image_tf.numpy()).permute(2, 0, 1)
patch_height, patch_width = 16, 16

patches_tf = tf.image.extract_patches(
    images=tf.expand_dims(image_tf, 0),
    sizes=[1, patch_height, patch_width, 1],
    strides=[1, patch_height, patch_width, 1],
    rates=[1, 1, 1, 1],
    padding="SAME"
)

patches_torch = torch_extract_patches(
    x=image_torch,
    patch_height=patch_height,
    patch_width=patch_width,
    padding="SAME"
)

assert torch.allclose(
    patches_torch.squeeze(0),
    torch.from_numpy(patches_tf.numpy()[0, :, :, :]),
    atol=1e-3,
    rtol=1e-3
)

Any help is appreciated!

Thanks in advance

Could you check if the following modification works for you?

import tensorflow as tf
import torch

import torch.nn.functional as F


# adapted from: https://discuss.pytorch.org/t/tf-extract-image-patches-in-pytorch/43837/8
def torch_extract_patches(
    x, patch_height, patch_width, padding=None
):
    x = x.unsqueeze(0)
    # Padding is skipped here: with H and W divisible by the patch size,
    # TF's "SAME" padding adds nothing, so the extra F.pad would shift the patch grid.
    #if padding == "SAME":
    #    x = F.pad(x, (1, 1, 1, 1))
    # Unfold into non-overlapping patches: [N, C * patch_h * patch_w, n_patches]
    patches = F.unfold(x, (patch_height, patch_width), stride=(patch_height, patch_width))
    # Separate the channel and per-patch dimensions: [N, C, patch_h, patch_w, n_patches]
    patches = patches.reshape(x.size(0), x.size(1), patch_height, patch_width, -1)
    # Reorder so channels vary fastest (matching TF's layout) and flatten each patch:
    # [n_patches_h, n_patches_w, patch_h * patch_w * C]
    patches = patches.permute(0, 4, 2, 3, 1).reshape(
        x.size(2) // patch_height,
        x.size(3) // patch_width,
        x.size(1) * patch_height * patch_width,
    )
    return patches

# H x W x C
image_tf = tf.random.uniform(shape=(720, 720, 3))
# C x H x W
image_torch = torch.from_numpy(image_tf.numpy()).permute(2, 0, 1)
patch_height, patch_width = 16, 16

patches_tf = tf.image.extract_patches(
    images=tf.expand_dims(image_tf, 0),
    sizes=[1, patch_height, patch_width, 1],
    strides=[1, patch_height, patch_width, 1],
    rates=[1, 1, 1, 1],
    padding="SAME"
)

patches_torch = torch_extract_patches(
    x=image_torch,
    patch_height=patch_height,
    patch_width=patch_width,
    padding="SAME"
)

pyt_patches = patches_torch
tf_patches = torch.from_numpy(patches_tf.numpy()[0, :, :, :])

assert torch.allclose(
    pyt_patches,
    tf_patches,
    atol=1e-3,
    rtol=1e-3
)

I believe using nn.Unfold would be closer to your use case. Also, please check the semantics of padding="SAME", as I believe it caused zeros to be added to the PyTorch output.
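
To illustrate the padding="SAME" remark, here is a minimal sketch (not from the original posts) for this exact setup: a 720x720 image with 16x16 patches and stride 16. Since 720 is divisible by 16, TF's "SAME" padding adds nothing (it is equivalent to "VALID"), while unconditionally padding by one pixel per side, as the first version did, shifts the whole patch grid:

import tensorflow as tf
import torch
import torch.nn.functional as F

image = tf.random.uniform(shape=(1, 720, 720, 3))

kwargs = dict(
    sizes=[1, 16, 16, 1],
    strides=[1, 16, 16, 1],
    rates=[1, 1, 1, 1],
)
same = tf.image.extract_patches(images=image, padding="SAME", **kwargs)
valid = tf.image.extract_patches(images=image, padding="VALID", **kwargs)

# 720 is divisible by 16, so "SAME" pads nothing and both outputs agree.
print(same.shape, valid.shape)             # (1, 45, 45, 768) (1, 45, 45, 768)
print(bool(tf.reduce_all(same == valid)))  # True

# Padding by one pixel per side turns the 720x720 input into 722x722:
# still a 45x45 grid of patches, but each patch starts one pixel earlier,
# so the values no longer line up with TF's output.
x = F.pad(torch.rand(1, 3, 720, 720), (1, 1, 1, 1))
print(x.shape)  # torch.Size([1, 3, 722, 722])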


Amazing, this works, thanks a bunch!