Hi there!
I am currently trying to reproduce `tf.image.extract_patches` in PyTorch for my use case, which is summarised in this gist: from `tf` to `torch` extract to patches (GitHub).
The implementations do not match (the assertion does not pass), and I am not sure what I am doing wrong. I made sure that the interpolation gives the same output, so the divergence starts in the respective "to patch" functions.
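For reference, TensorFlow's `"SAME"` padding is not a fixed one-pixel border. Under TF's convolution padding rules (assumed here to carry over to `extract_patches`), each spatial dimension gets `ceil(in / stride)` output positions, the total padding is whatever is needed to fit them, and an odd leftover pixel goes to the bottom/right. A minimal sketch of that arithmetic for `rates=1` (the helper name `tf_same_padding` is hypothetical):

```python
import math


def tf_same_padding(in_size: int, patch_size: int, stride: int):
    """Return (pad_before, pad_after) for one spatial dimension, TF 'SAME' style.

    Assumption: rates == 1 and TF's documented SAME rule for convolutions.
    """
    out_size = math.ceil(in_size / stride)  # number of output positions
    pad_total = max((out_size - 1) * stride + patch_size - in_size, 0)
    pad_before = pad_total // 2  # the extra pixel, if any, goes after
    return pad_before, pad_total - pad_before
```

For a 720-pixel side with 16 x 16 patches and stride 16 this gives `(0, 0)`, i.e. no padding at all, whereas a fixed `F.pad(x, (1, 1, 1, 1))` shifts every patch by one pixel.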
I could also continue in the original thread, Tf.extract_image_patches in pytorch - #8 by ptrblck, since the problem is quite similar.
Below is a simplified test that I am trying to make pass:
```python
import tensorflow as tf
import torch
import torch.nn.functional as F

# adapted from: https://discuss.pytorch.org/t/tf-extract-image-patches-in-pytorch/43837/8
def torch_extract_patches(
    x, patch_height, patch_width, padding=None
):
    x = x.unsqueeze(0)  # add batch dim: [1, C, H, W]
    if padding == "SAME":
        x = F.pad(x, (1, 1, 1, 1))
    # Non-overlapping windows: [1, C, n_h, n_w, patch_height, patch_width]
    patches = x.unfold(2, patch_height, patch_height).unfold(3, patch_width, patch_width)
    # Permute so that channels are next to patch dimension
    patches = patches.permute(0, 2, 3, 1, 5, 4).contiguous()  # [1, 45, 45, 3, 16, 16] for the test below
    # View as [batch_size, height, width, channels*kh*kw]
    patches = patches.reshape(*patches.size()[:3], -1)
    return patches

# H x W x C
image_tf = tf.random.uniform(shape=(720, 720, 3))
# C x H x W
image_torch = torch.from_numpy(image_tf.numpy()).permute(2, 0, 1)
patch_height, patch_width = 16, 16

patches_tf = tf.image.extract_patches(
    images=tf.expand_dims(image_tf, 0),
    sizes=[1, patch_height, patch_width, 1],
    strides=[1, patch_height, patch_width, 1],
    rates=[1, 1, 1, 1],
    padding="SAME"
)

patches_torch = torch_extract_patches(
    x=image_torch,
    patch_height=patch_height,
    patch_width=patch_width,
    padding="SAME"
)

assert torch.allclose(
    patches_torch.squeeze(0),
    torch.from_numpy(patches_tf.numpy()[0, :, :, :]),
    atol=1e-3,
    rtol=1e-3
)
```
Any help is appreciated!
Thanks in advance