I am training a simple ViT using the CIFAR10 dataset.
The training loop always gives the same loss, about 2.3 (≈ ln(10)), no matter how long I train.
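For reference, here is a minimal sketch of the loop I am running (simplified; CrossEntropyLoss and Adam are assumed here, and the batch size / learning rate are placeholders), using the ViT_RGB defined below:

import torch
import torch.nn as nn
import torchvision
import torchvision.transforms as T

# Sketch of the loop: assumed CrossEntropyLoss + Adam, placeholder hyperparameters
train_set = torchvision.datasets.CIFAR10(root="./data", train=True, download=True, transform=T.ToTensor())
train_loader = torch.utils.data.DataLoader(train_set, batch_size=128, shuffle=True)

model = ViT_RGB(img_size=32, patch_size=4)
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

for epoch in range(10):
    for images, labels in train_loader:
        optimizer.zero_grad()
        loss = criterion(model(images), labels)
        loss.backward()
        optimizer.step()
    print(epoch, loss.item())  # stays at ~2.30 every epoch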
The patching module is:
import torch
import torch.nn as nn

class Patcher(nn.Module):
    def __init__(self, patch_size):
        super(Patcher, self).__init__()
        self.patch_size = patch_size
        self.unfold = torch.nn.Unfold(kernel_size=patch_size, stride=patch_size)

    def forward(self, images):
        batch_size, channels, height, width = images.shape
        patch_height, patch_width = self.patch_size, self.patch_size
        assert height % patch_height == 0 and width % patch_width == 0, "Height and width must be divisible by the patch size."
        patches = self.unfold(images)  # (bs, C*P*P, N)
        patches = patches.view(batch_size, channels, patch_height, patch_width, -1).permute(0, 4, 1, 2, 3)  # (bs, N, C, P, P)
        return patches
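As a shape check, for a CIFAR10-sized batch with patch_size=4 (for example):

x = torch.randn(2, 3, 32, 32)   # CIFAR10-sized dummy batch
patches = Patcher(patch_size=4)(x)
print(patches.shape)            # torch.Size([2, 64, 3, 4, 4]): 64 patches of 3x4x4 per image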
The Transformer Block is:
class TransformerBlock(nn.Module):
    def __init__(self, model_dim, num_heads, mlp_ratio=4.0, dropout=0.1):
        super(TransformerBlock, self).__init__()
        self.norm1 = nn.LayerNorm(model_dim)
        self.attn = nn.MultiheadAttention(model_dim, num_heads, dropout=dropout)
        self.norm2 = nn.LayerNorm(model_dim)
        # Feedforward network
        self.mlp = nn.Sequential(
            nn.Linear(model_dim, int(model_dim * mlp_ratio)),
            nn.GELU(),
            nn.Linear(int(model_dim * mlp_ratio), model_dim),
            nn.Dropout(dropout),
        )

    def forward(self, x):
        # Self-attention (pre-norm, residual)
        norm1 = self.norm1(x)
        attn_out, _ = self.attn(norm1, norm1, norm1)
        x = x + attn_out
        # Feedforward network (pre-norm, residual)
        norm2 = self.norm2(x)
        mlp_out = self.mlp(norm2)
        x = x + mlp_out
        return x
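A quick check with a dummy token tensor in (batch, tokens, dim) order, which is the layout produced by the ViT below:

tokens = torch.randn(2, 65, 30)                      # (batch, 1 class token + 64 patches, model_dim)
block = TransformerBlock(model_dim=30, num_heads=3)
out = block(tokens)
print(out.shape)                                     # torch.Size([2, 65, 30]): shape is preserved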
The ViT is:
class ViT_RGB(nn.Module):
    def __init__(self, img_size, patch_size, model_dim=30, num_heads=3, num_layers=2, out_dim=10):
        super().__init__()
        self.img_size = img_size
        self.patch_size = patch_size
        self.n_patches = (self.img_size // self.patch_size) ** 2
        self.model_dim = model_dim
        self.num_layers = num_layers
        self.num_heads = num_heads
        self.out_dim = out_dim
        # 1) Patching
        self.patcher = Patcher(patch_size=self.patch_size)
        # 2) Linear projection
        self.linear_projector = nn.Linear(3 * self.patch_size ** 2, self.model_dim)
        # 3) Class token
        self.class_token = nn.Parameter(torch.rand(1, self.model_dim))  # This is common for all images! TODO
        # 4) Positional embedding
        # self.positional_embedding = nn.Embedding(self.n_patches + 1, self.model_dim)  # +1 for the class token; common for all images! TODO
        self.positional_embedding = nn.Parameter(torch.zeros(1, (img_size // patch_size) ** 2 + 1, model_dim))
        # 5) Transformer blocks
        self.blocks = nn.ModuleList([
            TransformerBlock(self.model_dim, self.num_heads) for _ in range(num_layers)
        ])
        # 6) Classification MLP
        self.mlp = nn.Sequential(
            nn.Linear(self.model_dim, self.out_dim),
            nn.Softmax(dim=-1),
        )
    def forward(self, images):
        patches = self.patcher(images)                               # (bs, N, C, P, P)
        patches = patches.flatten(start_dim=2)                       # (bs, N, C*P*P)
        patches = self.linear_projector(patches)                     # (bs, N, model_dim)
        batch_size = patches.shape[0]
        class_tokens = self.class_token.expand(batch_size, -1, -1)   # (bs, 1, model_dim)
        tokens = torch.cat((class_tokens, patches), dim=1)           # (bs, N+1, model_dim)
        # positions = torch.arange(tokens.shape[1])
        positions = self.positional_embedding
        tokens = tokens + positions
        for block in self.blocks:
            tokens = block(tokens)
        out = tokens[:, 0]        # class token
        out = self.mlp(out)
        return out
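Instantiated for CIFAR10 (again with patch_size=4 as an example) and run on a dummy batch, the model produces:

model = ViT_RGB(img_size=32, patch_size=4)
x = torch.randn(2, 3, 32, 32)
out = model(x)
print(out.shape)          # torch.Size([2, 10])
print(out.sum(dim=-1))    # each row sums to 1 because of the final Softmax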