Image patches for ViT

Hi,
I’m trying to recreate ViT from scratch. Firstly I wanna do Patches class and when I’m trying to see result… I get something weird.
Here is a link for images for more explanation:
https://imgur.com/a/DNgNFHL

Here is my code:

import numpy as np
import matplotlib.pyplot as plt
import torch
import torch.nn.functional as F
from torch import Tensor
from torch import nn
from pathlib import Path
from PIL import Image
from torchvision import transforms

class Patches(nn.Module):
    def __init__(self, img_size, patch_size, in_channel=3, embed_dim=300) -> None:
        super(Patches, self).__init__()
        self.img_size = img_size
        self.patch_size = patch_size
        self.n_patches = (img_size // patch_size) ** 2

        self.conv = nn.Conv2d(
            in_channels=in_channel,
            out_channels=embed_dim,
            kernel_size=(patch_size, patch_size),
            stride=patch_size
        )
    
    def forward(self, image):
        image = self.conv(image)
        image = image.flatten(2)
        image = image.transpose(1,2)

        return image
    
#TESTING#
img_path = Path('IMAGE')
image = Image.open(img_path)
image_size = 300
patch_size = 10

plt.figure(figsize=(4,4))
plt.imshow(image)
plt.axis('off')

resized_image = image.resize([image_size, image_size])
ToTensor = transforms.ToTensor()
ToPIL = transforms.ToPILImage()
tensor_image = ToTensor(resized_image)
tensor_image = tensor_image.unsqueeze(dim=0)
print(f'Image size: {tensor_image.shape}')
patches_extractor = Patches(img_size=image_size, patch_size=patch_size)
patches = patches_extractor(tensor_image)
print(f'Patches: {patches.shape}')

print(f'Image size: {image_size} X {image_size}')
print(f'Patch size: {patch_size} X {patch_size}')
print(f'Patches per image: {patches.shape[1]}')
print(f'Elements per patch: {patches.shape[-1]}')
n = int(np.sqrt(patches.shape[1]))
plt.figure(figsize=(4,4))
# to_txt = patches[0].numpy()
# np.savetxt('pytorch.txt',to_txt)
for i, patch in enumerate(patches[0]):
    ax = plt.subplot(n, n, i+1)
    patch_img = torch.reshape(patch,(3,patch_size, patch_size))
    plt.imshow(ToPIL(patch_img))
    plt.axis('off')
plt.show()

I have no idea what is wrong.

I’m unsure what exactly is unexpected as it seems you are passing patches to an uninitialized model and see noise in the output (which I would expect without training the model at all).
However, I also don’t fully understand which dimensions you are transposing in the forward method and why you are doing it.

You’re right I made some misunderstanding. I wanna recreate equivalent code in PyTorch. Here is code in Tensorflow:

import numpy as np
import pandas as pd
import tensorflow as tf
import tensorflow.keras.layers as L
import tensorflow_addons as tfa
import glob, random, os, warnings
import matplotlib.pyplot as plt
from sklearn.metrics import confusion_matrix, classification_report
import seaborn as sns
from pathlib import Path
from PIL import Image

class Patches(L.Layer):
    def __init__(self, patch_size):
        super(Patches, self).__init__()
        self.patch_size = patch_size

    def call(self, images):
        batch_size = tf.shape(images)[0]
        patches = tf.image.extract_patches(
            images = images,
            sizes = [1, self.patch_size, self.patch_size, 1],
            strides = [1, self.patch_size, self.patch_size, 1],
            rates = [1, 1, 1, 1],
            padding = 'VALID',
        )
        print(f'Forward patches: {patches.shape}')
        patch_dims = patches.shape[-1]
        patches = tf.reshape(patches, [batch_size, -1, patch_dims])
        return patches


img_path = Path('Image')
image = Image.open(img_path)
image_size = 300
patch_size = 10

plt.figure(figsize=(4, 4))

plt.imshow(image)
plt.axis('off')

resized_image = tf.image.resize(
    np.array(image), size = (image_size, image_size)
)
resized_image = tf.convert_to_tensor(resized_image)
resized_image = tf.expand_dims(resized_image, 0)
print(f'Image size: {resized_image.shape}')

patches = Patches(patch_size)(resized_image)
print(f'Patches: {patches.shape}')

print(f'Image size: {image_size} X {image_size}')
print(f'Patch size: {patch_size} X {patch_size}')
print(f'Patches per image: {patches.shape[1]}')
print(f'Elements per patch: {patches.shape[-1]}')

n = int(np.sqrt(patches.shape[1]))
plt.figure(figsize=(4, 4))
# to_txt = patches[0].numpy()
# np.savetxt('tensorflow.txt',to_txt)
for i, patch in enumerate(patches[0]):
    ax = plt.subplot(n, n, i + 1)
    patch_img = tf.reshape(patch, (patch_size, patch_size, 3))
    plt.imshow(patch_img.numpy().astype('uint8'))
    plt.axis('off')
plt.show()

And what I need is this image in example chopped into patches. I know it is possible to make this patches with some module, but I wanna use only PyTorch. So I found the function in Tensorflow tf.image.extract_patches() and I wanna do the same in PyTorch. I wanna add this chopped image into my documentation. I hope this brings some important information.

I also made another solution using function called unfold() with transpose() but it gives me the same noisy image.

I’m new in PyTorch and Computer Vision problems so sorry if I write something that doesn’t make any sense.

Using unfold is the right approach and works for me:

img = PIL.Image.open(PATH)

x = transforms.ToTensor()(img)
y = x.unfold(1, 100, 100).unfold(2, 100, 100)
y = y.contiguous().view(y.size(0), -1, 100, 100)
y = y.permute(1, 0, 2, 3).contiguous()
print(y.shape)

fig, axarr = plt.subplots(8, 8)
axarr = axarr.reshape(-1)
for idx, y_ in enumerate(y):
    axarr[idx].imshow(y_.permute(1, 2, 0).contiguous().numpy())
1 Like

Thank you, I made it to complicated. Your solution is much easier.