Hi,
I’m trying to recreate ViT from scratch. Firstly I wanna do Patches class and when I’m trying to see result… I get something weird.
Here is a link for images for more explanation:
https://imgur.com/a/DNgNFHL
Here is my code:
import numpy as np
import matplotlib.pyplot as plt
import torch
import torch.nn.functional as F
from torch import Tensor
from torch import nn
from pathlib import Path
from PIL import Image
from torchvision import transforms
class Patches(nn.Module):
def __init__(self, img_size, patch_size, in_channel=3, embed_dim=300) -> None:
super(Patches, self).__init__()
self.img_size = img_size
self.patch_size = patch_size
self.n_patches = (img_size // patch_size) ** 2
self.conv = nn.Conv2d(
in_channels=in_channel,
out_channels=embed_dim,
kernel_size=(patch_size, patch_size),
stride=patch_size
)
def forward(self, image):
image = self.conv(image)
image = image.flatten(2)
image = image.transpose(1,2)
return image
#TESTING#
img_path = Path('IMAGE')
image = Image.open(img_path)
image_size = 300
patch_size = 10
plt.figure(figsize=(4,4))
plt.imshow(image)
plt.axis('off')
resized_image = image.resize([image_size, image_size])
ToTensor = transforms.ToTensor()
ToPIL = transforms.ToPILImage()
tensor_image = ToTensor(resized_image)
tensor_image = tensor_image.unsqueeze(dim=0)
print(f'Image size: {tensor_image.shape}')
patches_extractor = Patches(img_size=image_size, patch_size=patch_size)
patches = patches_extractor(tensor_image)
print(f'Patches: {patches.shape}')
print(f'Image size: {image_size} X {image_size}')
print(f'Patch size: {patch_size} X {patch_size}')
print(f'Patches per image: {patches.shape[1]}')
print(f'Elements per patch: {patches.shape[-1]}')
n = int(np.sqrt(patches.shape[1]))
plt.figure(figsize=(4,4))
# to_txt = patches[0].numpy()
# np.savetxt('pytorch.txt',to_txt)
for i, patch in enumerate(patches[0]):
ax = plt.subplot(n, n, i+1)
patch_img = torch.reshape(patch,(3,patch_size, patch_size))
plt.imshow(ToPIL(patch_img))
plt.axis('off')
plt.show()
I have no idea what is wrong.