tf.image.extract_patches with "VALID" padding

I’ve been able to reimplement extract_patches and extract_image_patches in PyTorch for “SAME” padding, but how would I implement them for “VALID” padding when the strides and kernel sizes differ between the height and width dimensions?

For example:

import torch
import tensorflow as tf

batch_size = 4
channels = 128
height, width = 1, 600
kernel_height, kernel_width = 1, 16
stride_height, stride_width = 1, 4

x = torch.arange(0, batch_size*width*height*channels).view(batch_size, channels, height, width) # shape [4, 128, 1, 600]

tf_x = x.permute(0,2,3,1).numpy() # NCHW -> NHWC, shape [4, 1, 600, 128]

patches = x.unfold(2, kernel_height, stride_height).unfold(3, kernel_width, stride_width) # [4, 128, 1, 147, 1, 16]

patches = patches.permute(0, 2, 3, 1, 4, 5).contiguous()  # [4, 1, 147, 128, 1, 16]

patches = patches.view(*patches.size()[:3], -1) # [4, 1, 147, 2048]
print(patches[0,0,0,:])
tf_patches = tf.image.extract_patches(images=tf_x,
                          sizes=[1, kernel_height, kernel_width, 1],
                          strides=[1, stride_height, stride_width, 1],
                          rates=[1, 1, 1, 1],
                          padding='VALID') # [4, 1, 147, 2048]
print(tf_patches.numpy()[0,0,0,:])

assert torch.allclose(
    patches,
    torch.from_numpy(tf_patches.numpy()),
    atol=1e-3,
    rtol=1e-3
)
> tensor([    0,     1,     2,  ..., 76213, 76214, 76215])
> [    0   600  1200 ... 75015 75615 76215]

The shapes match, but the elements within each patch are in a different order.
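To make the mismatch easier to see, here is a minimal sketch on a tiny tensor (the names `channel_major` and `channel_minor` are mine, not from any library). It flattens the same unfolded window in two orders. The TensorFlow output above starts 0, 600, 1200, which is consistent with the channel-minor order, where channels vary fastest within a patch.

```python
import torch

# Tiny NCHW input: 1 image, 2 channels, height 1, width 3.
x = torch.arange(6).view(1, 2, 1, 3)  # x[0, c, 0, w] == c * 3 + w

# 1x2 windows with stride 1 -> [B, C, out_h, out_w, kh, kw]
p = x.unfold(2, 1, 1).unfold(3, 2, 1)

# Channel-major flattening (C, kh, kw): what the permute in the question produces.
channel_major = p.permute(0, 2, 3, 1, 4, 5).reshape(1, 1, 2, -1)
# Channel-minor flattening (kh, kw, C): channels vary fastest within a patch.
channel_minor = p.permute(0, 2, 3, 4, 5, 1).reshape(1, 1, 2, -1)

print(channel_major[0, 0, 0].tolist())  # [0, 1, 3, 4]
print(channel_minor[0, 0, 0].tolist())  # [0, 3, 1, 4]
```

Both results hold the same values, so a shape check alone cannot tell them apart.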

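Never mind, I found the problem. tf.image.extract_patches flattens each patch in (kernel_height, kernel_width, channels) order, so the channels vary fastest, while my permute put the channels first. The kernel axes have to be moved ahead of the channel axis before flattening.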

Fix:

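# x: [B, C, H, W]; unfold height, then width -> [B, C, out_h, out_w, kh, kw]
patches = x.unfold(2, kernel_height, stride_height).unfold(3, kernel_width, stride_width)
# Move the kernel axes ahead of the channels -> [B, kh, kw, C, out_h, out_w]
patches = patches.permute(0, 4, 5, 1, 2, 3).contiguous()
# Flatten (kh, kw, C) into one axis -> [B, kh*kw*C, out_h, out_w] (channels-first)
patches = patches.view(patches.shape[0], -1, patches.shape[-2], patches.shape[-1])
# patches.permute(0, 2, 3, 1) then has the same layout as tf_patches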
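For anyone who wants the result directly in the TensorFlow layout, the same idea can be wrapped in a small helper. This is only a sketch: `extract_patches_valid` is a name I made up, it assumes an NCHW input and dilation rates of 1, and it returns `[B, out_h, out_w, kh * kw * C]`.

```python
import torch


def extract_patches_valid(x, kernel, stride):
    """Sketch of 'VALID' patch extraction for an NCHW tensor (hypothetical helper).

    Returns [B, out_h, out_w, kh * kw * C] with channels varying fastest.
    """
    kh, kw = kernel
    sh, sw = stride
    # unfold keeps only windows that fit entirely inside the input, so no
    # padding is applied -> [B, C, out_h, out_w, kh, kw]
    p = x.unfold(2, kh, sh).unfold(3, kw, sw)
    # Reorder to [B, out_h, out_w, kh, kw, C]
    p = p.permute(0, 2, 3, 4, 5, 1)
    # reshape copies the permuted (non-contiguous) data as needed
    return p.reshape(p.shape[0], p.shape[1], p.shape[2], -1)


# Same input as in the question.
x = torch.arange(4 * 128 * 1 * 600).view(4, 128, 1, 600)
patches = extract_patches_valid(x, kernel=(1, 16), stride=(1, 4))
print(tuple(patches.shape))             # (4, 1, 147, 2048)
print(patches[0, 0, 0, :3].tolist())    # [0, 600, 1200]
print(patches[0, 0, 0, -1].item())      # 76215
```

The number of windows along each axis is `(size - kernel) // stride + 1`, which gives `(600 - 16) // 4 + 1 = 147` along the width here. The first and last printed values agree with the TensorFlow output shown in the question.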