What is wrong with my ViT?

I am training a simple ViT on the CIFAR-10 dataset.
The training loop always gives the same loss, about 2.3, no matter how many batches it sees (a sketch of the training setup is at the bottom of the post).

The patching module is:

import torch
import torch.nn as nn

class Patcher(nn.Module):
  def __init__(self, patch_size):
    super(Patcher, self).__init__()
    self.patch_size=patch_size
    self.unfold = torch.nn.Unfold(kernel_size=patch_size, stride=patch_size)

  def forward(self, images):
    batch_size, channels, height, width = images.shape
    patch_height, patch_width = self.patch_size, self.patch_size
    assert height % patch_height == 0 and width % patch_width == 0, "Height and width must be divisible by the patch size."

    patches = self.unfold(images)  # (batch_size, channels * patch_size * patch_size, n_patches)
    patches = patches.view(batch_size, channels, patch_height, patch_width, -1).permute(0, 4, 1, 2, 3)  # (batch_size, n_patches, channels, patch_size, patch_size)

    return patches
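
A quick shape check of the patcher on a CIFAR-10-sized batch (a standalone snippet; the batch size and patch size here are arbitrary):

images = torch.randn(8, 3, 32, 32)    # dummy CIFAR-10 batch
patcher = Patcher(patch_size=4)
patches = patcher(images)
print(patches.shape)                  # torch.Size([8, 64, 3, 4, 4]): 64 patches of 3 x 4 x 4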

The Transformer Block is:

class TransformerBlock(nn.Module):
    def __init__(self, model_dim, num_heads, mlp_ratio=4.0, dropout=0.1):
        super(TransformerBlock, self).__init__()
        self.norm1 = nn.LayerNorm(model_dim)
        self.attn = nn.MultiheadAttention(model_dim, num_heads, dropout=dropout)
        self.norm2 = nn.LayerNorm(model_dim)

        # Feedforward network
        self.mlp = nn.Sequential(
            nn.Linear(model_dim, int(model_dim * mlp_ratio)),
            nn.GELU(),
            nn.Linear(int(model_dim * mlp_ratio), model_dim),
            nn.Dropout(dropout),
        )

    def forward(self, x):
        # Self-attention
        norm1 = self.norm1(x)
        attn_out, _ = self.attn(norm1, norm1, norm1)
        x = x + attn_out

        # Feedforward network
        norm2 = self.norm2(x)
        mlp_out = self.mlp(norm2)
        x = x + mlp_out

        return x
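
A quick standalone check of one block (shape test only; the sizes match what the ViT below uses, i.e. 64 patches plus one class token with model_dim = 30):

block = TransformerBlock(model_dim=30, num_heads=3)
tokens = torch.randn(8, 65, 30)   # (batch, tokens, dim), the layout the ViT below passes in
out = block(tokens)
print(out.shape)                  # torch.Size([8, 65, 30])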

The ViT is:

class ViT_RGB(nn.Module):
  def __init__(self, img_size, patch_size, model_dim=30, num_heads=3, num_layers=2, out_dim=10):
    super().__init__()
    self.img_size = img_size
    self.patch_size = patch_size
    self.n_patches = (self.img_size // self.patch_size) ** 2
    self.model_dim = model_dim
    self.num_layers = num_layers
    self.num_heads = num_heads
    self.out_dim = out_dim

    # 1) Patching
    self.patcher = Patcher(patch_size=self.patch_size)
    
    # 2) Linear Projection
    self.linear_projector = nn.Linear(3 * self.patch_size ** 2, self.model_dim)

    # 3) Class Token
    self.class_token = nn.Parameter(torch.rand(1, self.model_dim)) # This is common to all images! TODO

    # 4) Positional Embedding
    #self.positional_embedding = nn.Embedding(self.n_patches+1, self.model_dim) # +1 for the class token # Common for all images! TODO
    self.positional_embedding = nn.Parameter(torch.zeros(1,(img_size // patch_size) ** 2 + 1, model_dim))

    # 5) Transformer blocks
    self.blocks = nn.ModuleList([
        TransformerBlock( self.model_dim,  self.num_heads) for _ in range(num_layers)
    ])

    # 6) Classification MLP
    self.mlp = nn.Sequential(
            nn.Linear(self.model_dim, self.out_dim),
            nn.Softmax(dim=-1)
        )

  def forward(self, images):

    patches = self.patcher(images)

    patches = patches.flatten(start_dim=2)
    patches = self.linear_projector(patches)

    batch_size = patches.shape[0]
    class_tokens = self.class_token.expand(batch_size, -1, -1)
    tokens = torch.cat((class_tokens, patches), dim=1)

    #positions = torch.arange(tokens.shape[1])
    positions = self.positional_embedding
    tokens = tokens + positions

    for block in self.blocks:
      tokens = block(tokens)

    out = tokens[:, 0]
    out = self.mlp(out)

    return out
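
For reference, the training setup is along these lines (a minimal sketch: the optimizer, learning rate, and batch size are placeholders rather than my exact values; note that 2.3 is roughly ln(10), i.e. the loss of a uniform prediction over the 10 CIFAR-10 classes):

import torch
import torch.nn as nn
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader

train_set = torchvision.datasets.CIFAR10(root="./data", train=True, download=True,
                                         transform=transforms.ToTensor())
train_loader = DataLoader(train_set, batch_size=128, shuffle=True)

device = "cuda" if torch.cuda.is_available() else "cpu"
model = ViT_RGB(img_size=32, patch_size=4).to(device)
criterion = nn.CrossEntropyLoss()                              # assumed loss
optimizer = torch.optim.Adam(model.parameters(), lr=3e-4)      # placeholder optimizer / lr

for epoch in range(5):
    running_loss = 0.0
    for images, labels in train_loader:
        images, labels = images.to(device), labels.to(device)
        optimizer.zero_grad()
        logits = model(images)             # (batch, 10)
        loss = criterion(logits, labels)
        loss.backward()
        optimizer.step()
        running_loss += loss.item()
    print(f"epoch {epoch}: mean loss {running_loss / len(train_loader):.3f}")   # stays around 2.3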