Implementing Early Convolutions for Improved Accuracy

I followed the paper ‘Early Convolutions Help Transformers See Better’ to implement a convolutional stem (ConvStem) for my pre-trained Vision Transformer (ViT). However, when I fine-tuned the model, the accuracy dropped from 45% to 2% after a few epochs. The paper suggests that this modification should improve accuracy, so I am not sure what went wrong. Here is the code for my model:

class ConvStem(nn.Module):
    """ConvStem, from Early Convolutions Help Transformers See Better, Tete et al."""
    def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768, norm_layer=None, flatten=True):
        super().__init__()
        assert patch_size == 16, 'ConvStem only supports patch size of 16'
        assert embed_dim % 8 == 0, 'Embed dimension must be divisible by 8 for ConvStem'
        img_size = to_2tuple(img_size)
        patch_size = to_2tuple(patch_size)
        self.img_size = img_size
        self.patch_size = patch_size
        self.grid_size = (img_size[0] // patch_size[0], img_size[1] // patch_size[1])
        self.num_patches = self.grid_size[0] * self.grid_size[1]
        self.flatten = flatten
        # build stem, similar to the design in the paper cited above
        stem = []
        input_dim, output_dim = 3, embed_dim // 8
        for l in range(4):
            stem.append(nn.Conv2d(input_dim, output_dim, kernel_size=3, stride=2, padding=1, bias=False))
            input_dim = output_dim
            output_dim *= 2
        stem.append(nn.Conv2d(input_dim, embed_dim, kernel_size=1))
        self.proj = nn.Sequential(*stem)
        self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()
    def forward(self, x):
        B, C, H, W = x.shape
        assert H == self.img_size[0] and W == self.img_size[1], \
            f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
        x = self.proj(x)
        if self.flatten:
            x = x.flatten(2).transpose(1, 2)  # BCHW -> BNC
        x = self.norm(x)
        return x
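
As a quick sanity check on the stem above, here is a small sketch of the spatial arithmetic. It assumes the standard convolution output-size formula and the default 224x224 input; the helper name conv_out_size is made up for illustration. It shows that four stride-2 convolutions give the same 14x14 token grid as a 16x16 patch embedding, so the pretrained positional embedding still has the right length.

```python
def conv_out_size(n, kernel=3, stride=2, padding=1):
    # Standard convolution output-size formula: floor((n + 2p - k) / s) + 1
    return (n + 2 * padding - kernel) // stride + 1

size = 224
sizes = [size]
for _ in range(4):  # four stride-2 3x3 convolutions, as in the stem
    size = conv_out_size(size)
    sizes.append(size)

grid = sizes[-1]
num_tokens = grid * grid
print(sizes)       # [224, 112, 56, 28, 14]
print(num_tokens)  # 196, the same as (224 // 16) ** 2
```

So the token count matches, but the values the stem produces are unrelated to what the pretrained patch embedding produced.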

Code for my Vision Transformer:

class VisionTransformer(nn.Module):
    def __init__(self, input_resolution: int, patch_size: int, width: int, layers: int, heads: int, output_dim: int):
        super().__init__()
        self.input_resolution = input_resolution
        self.output_dim = output_dim
        # self.conv1 = nn.Conv2d(in_channels=3, out_channels=width, kernel_size=patch_size, stride=patch_size, bias=False)
        scale = width ** -0.5
        self.class_embedding = nn.Parameter(scale * torch.randn(width))
        self.positional_embedding = nn.Parameter(scale * torch.randn((input_resolution // patch_size) ** 2 + 1, width))
        self.ln_pre = LayerNorm(width)
        self.transformer = Transformer(width, layers, heads)
        self.ln_post = LayerNorm(width)
        self.proj = nn.Parameter(scale * torch.randn(width, output_dim))
        self.conv_layers = ConvStem(img_size=input_resolution,
                                    norm_layer=partial(nn.LayerNorm, eps=1e-6), embed_dim=width)

    def forward(self, x: torch.Tensor):
        # x = self.conv1(x)  # shape = [*, width, grid, grid]
        x = self.conv_layers(x)
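        # ConvStem (flatten=True) already returns [*, grid ** 2, width], so this reshape is a no-op
        x = x.reshape(x.shape[0], x.shape[1], -1)
        # x = x.permute(0, 2, 1)  # shape = [*, grid ** 2, width]
        x = torch.cat([self.class_embedding.to(x.dtype) + torch.zeros(x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device), x], dim=1)  # shape = [*, grid ** 2 + 1, width]
        x = x + self.positional_embedding.to(x.dtype)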
        x = self.ln_pre(x)
        x = x.permute(1, 0, 2)  # NLD -> LND
        x = self.transformer(x)
        x = x.permute(1, 0, 2)  # LND -> NLD
        x = self.ln_post(x[:, 0, :])
        if self.proj is not None:
            x = x @ self.proj
        return x

Can you show where you instantiate the model? Also, are you freezing the weights of the pretrained ViT model before training?


currect_model_dict = model.state_dict()
matched_pretrained = {k: v for k, v in pretrained_dict.items() if k in currect_model_dict.keys()}
def _initialize_missing_weights(model, pretrained_dict):
    for name, param in model.named_parameters():
        if name not in pretrained_dict:
            print(f"Initializing missing weight: {name}")
            trunc_normal_(param, std=0.02)
_initialize_missing_weights(model, matched_pretrained)
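
The snippet above filters the pretrained keys and initializes the missing ones, but it does not show the matched weights being copied into the model. If that happens elsewhere, ignore this. Otherwise, here is a minimal sketch of a partial load using load_state_dict with strict=False. The two tiny models (old, new) are hypothetical stand-ins, not the CLIP model.

```python
import torch
from torch import nn

# Hypothetical stand-ins: "old" plays the pretrained model, "new" adds a stem.
old = nn.Sequential()
old.add_module("body", nn.Linear(4, 4))

new = nn.Sequential()
new.add_module("stem", nn.Linear(4, 4))
new.add_module("body", nn.Linear(4, 4))

pretrained_dict = old.state_dict()
current = new.state_dict()
# keep only keys that exist in the new model with the same shape
matched = {k: v for k, v in pretrained_dict.items()
           if k in current and v.shape == current[k].shape}

# strict=False tolerates keys that exist only in the new model
result = new.load_state_dict(matched, strict=False)
print("missing (left at their init):", result.missing_keys)
print("unexpected:", result.unexpected_keys)
```

Printing missing_keys is a cheap way to confirm that only the new stem was left at its fresh initialization.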

I tried two fine-tuning approaches: training all Vision Transformer (ViT) layers, including the ConvStem, and alternatively freezing all ViT layers and training only the ConvStem.

It’s funny. I actually had that paper downloaded on my phone from a long time ago and vaguely remember reading it.

Anyway, it’s a simple matter. Here is what I need to see:

Can you show the code for:

  1. Where you define the optimizer;
  2. Where you define both the ViT and ConvStem models;
  3. Where you freeze the weights of the ViT model.

Thanks for helping me out. I am using the pre-trained CLIP ViT-B/16 model.

def load_clip_to_cpu(backbone_name):
    url = clip._MODELS[backbone_name]
    model_path = clip._download(url)
    try:
        # first try to load the checkpoint as a JIT archive
        model = torch.jit.load(model_path, map_location="cpu").eval()
        state_dict = None
    except RuntimeError:
        # fall back to a plain state dict
        state_dict = torch.load(model_path, map_location="cpu")
    pretrained_dict = state_dict or model.state_dict()
    model = clip.build_model(pretrained_dict)
    return model

The structure of my code is as follows:

  • I define the load_clip_to_cpu function.

  • I define the train function, where the optimizer is set up like this:
    optimizer = torch.optim.Adam(model.parameters(), lr=args.learning_rate, weight_decay=args.weight_decay)

  • model: I define the CLIP ViT-B/16 like this:

from collections import OrderedDict
from typing import Tuple, Union
import numpy as np
import torch
import torch.nn.functional as F
from torch import nn
from timm.models.layers.helpers import to_2tuple
from functools import partial, reduce
from timm.models.layers import trunc_normal_

class Bottleneck(nn.Module):
    expansion = 4
    def __init__(self, inplanes, planes, stride=1):
        super().__init__()
        # all conv layers have stride 1. an avgpool is performed after the second convolution when stride > 1
        self.conv1 = nn.Conv2d(inplanes, planes, 1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, 3, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)
        self.avgpool = nn.AvgPool2d(stride) if stride > 1 else nn.Identity()
        self.conv3 = nn.Conv2d(planes, planes * self.expansion, 1, bias=False)
        self.bn3 = nn.BatchNorm2d(planes * self.expansion)
        self.relu = nn.ReLU(inplace=True)
        self.downsample = None
        self.stride = stride
        if stride > 1 or inplanes != planes * Bottleneck.expansion:
            # downsampling layer is prepended with an avgpool, and the subsequent convolution has stride 1
            self.downsample = nn.Sequential(OrderedDict([
                ("-1", nn.AvgPool2d(stride)),
                ("0", nn.Conv2d(inplanes, planes * self.expansion, 1, stride=1, bias=False)),
                ("1", nn.BatchNorm2d(planes * self.expansion))
            ]))
    def forward(self, x: torch.Tensor):
        identity = x
        out = self.relu(self.bn1(self.conv1(x)))
        out = self.relu(self.bn2(self.conv2(out)))
        out = self.avgpool(out)
        out = self.bn3(self.conv3(out))
        if self.downsample is not None:
            identity = self.downsample(x)
        out += identity
        out = self.relu(out)
        return out

class AttentionPool2d(nn.Module):
    def __init__(self, spacial_dim: int, embed_dim: int, num_heads: int, output_dim: int = None):
        super().__init__()
        self.positional_embedding = nn.Parameter(torch.randn(spacial_dim ** 2 + 1, embed_dim) / embed_dim ** 0.5)
        self.k_proj = nn.Linear(embed_dim, embed_dim)
        self.q_proj = nn.Linear(embed_dim, embed_dim)
        self.v_proj = nn.Linear(embed_dim, embed_dim)
        self.c_proj = nn.Linear(embed_dim, output_dim or embed_dim)
        self.num_heads = num_heads
    def forward(self, x):
        x = x.reshape(x.shape[0], x.shape[1], x.shape[2] * x.shape[3]).permute(2, 0, 1)  # NCHW -> (HW)NC
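        x = torch.cat([x.mean(dim=0, keepdim=True), x], dim=0)  # (HW+1)NC
        x = x + self.positional_embedding[:, None, :].to(x.dtype)  # (HW+1)NC
        # keyword arguments follow the reference CLIP implementation
        x, _ = F.multi_head_attention_forward(
            query=x, key=x, value=x,
            embed_dim_to_check=x.shape[-1],
            num_heads=self.num_heads,
            q_proj_weight=self.q_proj.weight,
            k_proj_weight=self.k_proj.weight,
            v_proj_weight=self.v_proj.weight,
            in_proj_weight=None,
            in_proj_bias=torch.cat([self.q_proj.bias, self.k_proj.bias, self.v_proj.bias]),
            bias_k=None, bias_v=None, add_zero_attn=False, dropout_p=0,
            out_proj_weight=self.c_proj.weight,
            out_proj_bias=self.c_proj.bias,
            use_separate_proj_weight=True,
            training=self.training,
            need_weights=False
        )
        return x[0]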

class ModifiedResNet(nn.Module):
    """
    A ResNet class that is similar to torchvision's but contains the following changes:
    - There are now 3 "stem" convolutions as opposed to 1, with an average pool instead of a max pool.
    - Performs anti-aliasing strided convolutions, where an avgpool is prepended to convolutions with stride > 1
    - The final pooling layer is a QKV attention instead of an average pool
    """
    def __init__(self, layers, output_dim, heads, input_resolution=224, width=64):
        super().__init__()
        self.output_dim = output_dim
        self.input_resolution = input_resolution
        # the 3-layer stem
        self.conv1 = nn.Conv2d(3, width // 2, kernel_size=3, stride=2, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(width // 2)
        self.conv2 = nn.Conv2d(width // 2, width // 2, kernel_size=3, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(width // 2)
        self.conv3 = nn.Conv2d(width // 2, width, kernel_size=3, padding=1, bias=False)
        self.bn3 = nn.BatchNorm2d(width)
        self.avgpool = nn.AvgPool2d(2)
        self.relu = nn.ReLU(inplace=True)
        # residual layers
        self._inplanes = width  # this is a *mutable* variable used during construction
        self.layer1 = self._make_layer(width, layers[0])
        self.layer2 = self._make_layer(width * 2, layers[1], stride=2)
        self.layer3 = self._make_layer(width * 4, layers[2], stride=2)
        self.layer4 = self._make_layer(width * 8, layers[3], stride=2)
        embed_dim = width * 32  # the ResNet feature dimension
        self.attnpool = AttentionPool2d(input_resolution // 32, embed_dim, heads, output_dim)
    def _make_layer(self, planes, blocks, stride=1):
        layers = [Bottleneck(self._inplanes, planes, stride)]
        self._inplanes = planes * Bottleneck.expansion
        for _ in range(1, blocks):
            layers.append(Bottleneck(self._inplanes, planes))
        return nn.Sequential(*layers)
    def forward(self, x):
        def stem(x):
            for conv, bn in [(self.conv1, self.bn1), (self.conv2, self.bn2), (self.conv3, self.bn3)]:
                x = self.relu(bn(conv(x)))
            x = self.avgpool(x)
            return x
        x = x.type(self.conv1.weight.dtype)
        x = stem(x)
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)
        x = self.attnpool(x)
        return x

class LayerNorm(nn.LayerNorm):
    """Subclass torch's LayerNorm to handle fp16."""
    def forward(self, x: torch.Tensor):
        orig_type = x.dtype
        ret = super().forward(x.type(torch.float32))
        return ret.type(orig_type)

class QuickGELU(nn.Module):
    def forward(self, x: torch.Tensor):
        return x * torch.sigmoid(1.702 * x)

class ResidualAttentionBlock(nn.Module):
    def __init__(self, d_model: int, n_head: int, attn_mask: torch.Tensor = None):
        super().__init__()
        self.attn = nn.MultiheadAttention(d_model, n_head)
        self.ln_1 = LayerNorm(d_model)
        self.mlp = nn.Sequential(OrderedDict([
            ("c_fc", nn.Linear(d_model, d_model * 4)),
            ("gelu", QuickGELU()),
            ("c_proj", nn.Linear(d_model * 4, d_model))
        ]))
        self.ln_2 = LayerNorm(d_model)
        self.attn_mask = attn_mask
    def attention(self, x: torch.Tensor):
        self.attn_mask = self.attn_mask.to(dtype=x.dtype, device=x.device) if self.attn_mask is not None else None
        return self.attn(x, x, x, need_weights=False, attn_mask=self.attn_mask)[0]
    def forward(self, x: torch.Tensor):
        x = x + self.attention(self.ln_1(x))
        x = x + self.mlp(self.ln_2(x))
        return x

class Transformer(nn.Module):
    def __init__(self, width: int, layers: int, heads: int, attn_mask: torch.Tensor = None):
        super().__init__()
        self.width = width
        self.layers = layers
        self.resblocks = nn.Sequential(*[ResidualAttentionBlock(width, heads, attn_mask) for _ in range(layers)])
    def forward(self, x: torch.Tensor):
        return self.resblocks(x)

class CLIP(nn.Module):
    def __init__(self,
                 embed_dim: int,
                 # vision
                 image_resolution: int,
                 vision_layers: Union[Tuple[int, int, int, int], int],
                 vision_width: int,
                 vision_patch_size: int,
                 # text
                 context_length: int,
                 vocab_size: int,
                 transformer_width: int,
                 transformer_heads: int,
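                 transformer_layers: int
                 ):
        super().__init__()

        self.context_length = context_length

        if isinstance(vision_layers, (tuple, list)):
            vision_heads = vision_width * 32 // 64
            self.visual = ModifiedResNet(
                layers=vision_layers, output_dim=embed_dim, heads=vision_heads,
                input_resolution=image_resolution, width=vision_width)
        else:
            vision_heads = vision_width // 64
            self.visual = VisionTransformer(
                input_resolution=image_resolution, patch_size=vision_patch_size,
                width=vision_width, layers=vision_layers, heads=vision_heads,
                output_dim=embed_dim)
        self.transformer = Transformer(
            width=transformer_width, layers=transformer_layers,
            heads=transformer_heads, attn_mask=self.build_attention_mask())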
        self.vocab_size = vocab_size
        self.token_embedding = nn.Embedding(vocab_size, transformer_width)
        self.positional_embedding = nn.Parameter(torch.empty(self.context_length, transformer_width))
        self.ln_final = LayerNorm(transformer_width)
        self.text_projection = nn.Parameter(torch.empty(transformer_width, embed_dim))
        self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))

    def initialize_parameters(self):
        nn.init.normal_(self.token_embedding.weight, std=0.02)
        nn.init.normal_(self.positional_embedding, std=0.01)
        if isinstance(self.visual, ModifiedResNet):
            if self.visual.attnpool is not None:
                std = self.visual.attnpool.c_proj.in_features ** -0.5
                nn.init.normal_(self.visual.attnpool.q_proj.weight, std=std)
                nn.init.normal_(self.visual.attnpool.k_proj.weight, std=std)
                nn.init.normal_(self.visual.attnpool.v_proj.weight, std=std)
                nn.init.normal_(self.visual.attnpool.c_proj.weight, std=std)
            for resnet_block in [self.visual.layer1, self.visual.layer2, self.visual.layer3, self.visual.layer4]:
                for name, param in resnet_block.named_parameters():
                    if name.endswith("bn3.weight"):
                        nn.init.zeros_(param)
        proj_std = (self.transformer.width ** -0.5) * ((2 * self.transformer.layers) ** -0.5)
        attn_std = self.transformer.width ** -0.5
        fc_std = (2 * self.transformer.width) ** -0.5
        for block in self.transformer.resblocks:
            nn.init.normal_(block.attn.in_proj_weight, std=attn_std)
            nn.init.normal_(block.attn.out_proj.weight, std=proj_std)
            nn.init.normal_(block.mlp.c_fc.weight, std=fc_std)
            nn.init.normal_(block.mlp.c_proj.weight, std=proj_std)
        if self.text_projection is not None:
            nn.init.normal_(self.text_projection, std=self.transformer.width ** -0.5)

    def build_attention_mask(self):
        # lazily create causal attention mask, with full attention between the vision tokens
        # pytorch uses additive attention mask; fill with -inf
        mask = torch.empty(self.context_length, self.context_length)
        mask.fill_(float("-inf"))
        mask.triu_(1)  # zero out the lower diagonal
        return mask

    @property
    def dtype(self):
        # return self.visual.conv1.weight.dtype
        return self.visual.conv_layers.proj[0].weight.dtype

    def encode_image(self, image):
        return self.visual(image.type(self.dtype))

    def encode_text(self, text):
        x = self.token_embedding(text).type(self.dtype)  # [batch_size, n_ctx, d_model]
        x = x + self.positional_embedding.type(self.dtype)
        x = x.permute(1, 0, 2)  # NLD -> LND
        x = self.transformer(x)
        x = x.permute(1, 0, 2)  # LND -> NLD
        x = self.ln_final(x).type(self.dtype)
        # x.shape = [batch_size, n_ctx, transformer.width]
        # take features from the eot embedding (eot_token is the highest number in each sequence)
        x = x[torch.arange(x.shape[0]), text.argmax(dim=-1)] @ self.text_projection
        return x

    def forward(self, image, text):
        image_features = self.encode_image(image)
        text_features = self.encode_text(text)
        # normalized features
        image_features = image_features / image_features.norm(dim=-1, keepdim=True)
        text_features = text_features / text_features.norm(dim=-1, keepdim=True)
        # cosine similarity as logits
        logit_scale = self.logit_scale.exp()
        logits_per_image = logit_scale * image_features @ text_features.t()
        logits_per_text = logit_scale * text_features @ image_features.t()
        # shape = [global_batch_size, global_batch_size]
        return logits_per_image, logits_per_text

def convert_weights(model: nn.Module):
    """Convert applicable model parameters to fp16"""
    def _convert_weights_to_fp16(l):
        if isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Linear)):
            l.weight.data = l.weight.data.half()
            if l.bias is not None:
                l.bias.data = l.bias.data.half()
        if isinstance(l, nn.MultiheadAttention):
            for attr in [*[f"{s}_proj_weight" for s in ["in", "q", "k", "v"]], "in_proj_bias", "bias_k", "bias_v"]:
                tensor = getattr(l, attr)
                if tensor is not None:
                    tensor.data = tensor.data.half()
        for name in ["text_projection", "proj"]:
            if hasattr(l, name):
                attr = getattr(l, name)
                if attr is not None:
                    attr.data = attr.data.half()
    model.apply(_convert_weights_to_fp16)

def build_model(pretrained_dict: dict):
    vit = "visual.proj" in pretrained_dict
    # ---------------- load Vision Transformer --------------------------------
    if vit:
        vision_width = pretrained_dict["visual.conv1.weight"].shape[0]
        vision_layers = len([k for k in pretrained_dict.keys() if k.startswith("visual.") and k.endswith(".attn.in_proj_weight")])
        vision_patch_size = pretrained_dict["visual.conv1.weight"].shape[-1]
        grid_size = round((pretrained_dict["visual.positional_embedding"].shape[0] - 1) ** 0.5)
        image_resolution = vision_patch_size * grid_size
    else:
        counts: list = [len(set(k.split(".")[2] for k in pretrained_dict if k.startswith(f"visual.layer{b}"))) for b in [1, 2, 3, 4]]
        vision_layers = tuple(counts)
        vision_width = pretrained_dict["visual.layer1.0.conv1.weight"].shape[0]
        output_width = round((pretrained_dict["visual.attnpool.positional_embedding"].shape[0] - 1) ** 0.5)
        vision_patch_size = None
        assert output_width ** 2 + 1 == pretrained_dict["visual.attnpool.positional_embedding"].shape[0]
        image_resolution = output_width * 32
    # ---------------- load Text Transformer --------------------------------
    embed_dim = pretrained_dict["text_projection"].shape[1]
    context_length = pretrained_dict["positional_embedding"].shape[0]
    vocab_size = pretrained_dict["token_embedding.weight"].shape[0]
    transformer_width = pretrained_dict["ln_final.weight"].shape[0]
    transformer_heads = transformer_width // 64
    transformer_layers = len(set(k.split(".")[2] for k in pretrained_dict if k.startswith("transformer.resblocks")))

    model = CLIP(embed_dim, image_resolution, vision_layers, vision_width, vision_patch_size,
        context_length, vocab_size, transformer_width, transformer_heads, transformer_layers)

    for key in ["input_resolution", "context_length", "vocab_size"]:
        if key in pretrained_dict:
            del pretrained_dict[key]

    currect_model_dict = model.state_dict()
    matched_pretrained = {k: v for k, v in pretrained_dict.items() if k in currect_model_dict.keys()}
    def _initialize_missing_weights(model, pretrained_dict):
        for name, param in model.named_parameters():
            if name not in pretrained_dict:
                print(f"Initializing missing weight: {name}")
                trunc_normal_(param, std=0.02)
    _initialize_missing_weights(model, matched_pretrained)
    # convert_weights(model)
    return model.eval()

A couple of points:

  1. Your optimizer is taking the parameters for everything. That means it’s going to be optimizing all, including the pretrained weights. You should only pass into the optimizer precisely what you want to be optimized. Here is a short example:
import torch
import torch.nn as nn

model = nn.Sequential(nn.Linear(1, 2), nn.BatchNorm1d(2), nn.Linear(2, 1))

# only pass in the second Linear layer for training to the optimizer
optimizer = torch.optim.Adam(model[2].parameters(), lr=0.001)

# may as well define a loss function
criterion = nn.MSELoss()

model = model.eval() # this will just turn off things like batchnorm and dropout, it does not shut off autograd

dummy_data = torch.rand(10, 1)
output = model(dummy_data)

loss = criterion(output, torch.rand(10, 1))

# weights before optimization step
for param in model.parameters():
    print(param)

loss.backward()
optimizer.step()

# compare which weights updated and which did not:
for param in model.parameters():
    print(param)
  2. model.eval() changes the behavior of layers like batchnorm and turns off dropout. That is what we want at inference time, but not for layers that are being trained. You will likely want eval() active only on the pretrained part of the model. Since each part is an nn.Module, you can switch the mode for that submodule individually. For example:
model.visual = model.visual.eval() # turn eval on for all
model.visual.conv_layers = model.visual.conv_layers.train() # go back and turn train on for your ConvStem
  3. To speed up training, you can also turn off grad for any parts of the model that don’t need it, i.e. the layers not being optimized, via model.layer1.requires_grad_(False).

  4. Off-topic question: I thought the original paper used the ConvStem before the ViT, but it seems you have it after.
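
The points above about the optimizer, eval mode, and requires_grad can be combined into one small sketch. TinyViT is a made-up stand-in for model.visual (its conv_layers attribute mirrors the name in your code, everything else is hypothetical), so treat this as an illustration of the pattern rather than a drop-in fix.

```python
import torch
from torch import nn

class TinyViT(nn.Module):  # hypothetical stand-in for model.visual
    def __init__(self):
        super().__init__()
        self.conv_layers = nn.Sequential(nn.Linear(4, 8), nn.BatchNorm1d(8))  # new stem
        self.transformer = nn.Linear(8, 8)  # stands in for the pretrained blocks

    def forward(self, x):
        return self.transformer(self.conv_layers(x))

torch.manual_seed(0)
visual = TinyViT()

# 1) freeze everything, then re-enable gradients for the new stem only
visual.requires_grad_(False)
visual.conv_layers.requires_grad_(True)

# 2) eval mode for the pretrained part, train mode for the stem
visual.eval()
visual.conv_layers.train()

# 3) the optimizer only sees the trainable parameters
trainable = [p for p in visual.parameters() if p.requires_grad]
optimizer = torch.optim.Adam(trainable, lr=1e-3)

before = {n: p.detach().clone() for n, p in visual.named_parameters()}
loss = visual(torch.rand(16, 4)).pow(2).mean()
loss.backward()
optimizer.step()

# names of the parameters that actually moved after one step
changed = [n for n, p in visual.named_parameters() if not torch.equal(before[n], p)]
print(changed)
```

Only names under conv_layers should show up in the printed list. Note that calling model.train() later in a training loop would switch every submodule back to train mode, so step 2 has to be repeated after it.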

But, I put it before the ViT

        x = self.conv_layers(x)
        x = x.reshape(x.shape[0], x.shape[1], -1)  # no-op: ConvStem already returns [*, grid ** 2, width]
        # x = x.permute(0, 2, 1)  # shape = [*, grid ** 2, width]
        x = torch.cat([self.class_embedding.to(x.dtype) + torch.zeros(x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device), x], dim=1)  # shape = [*, grid ** 2 + 1, width]
        x = x + self.positional_embedding.to(x.dtype)
        x = self.ln_pre(x)
        x = x.permute(1, 0, 2)  # NLD -> LND
        x = self.transformer(x)
        x = x.permute(1, 0, 2)  # LND -> NLD
        x = self.ln_post(x[:, 0, :])
        if self.proj is not None:
            x = x @ self.proj

I see. Didn’t look at the forward pass.

The issue is not in the training itself. Even when I use the pretrained model weights with the ConvStem for plain inference, without any training, the accuracy on the test dataset drops from 45% to 2%.

That’s to be expected. If you insert a new layer into any model for additional finetuning, it’s going to initially have very poor results (except under very specific circumstances).
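
One example of such a specific circumstance, as a hedged sketch: a new residual block whose last projection is zero-initialized starts out as an identity map, so the pretrained outputs are unchanged until training moves it. ResidualAdapter is a hypothetical module invented for this illustration. A replacement stem like ConvStem cannot be set up this way, because it replaces the pretrained patch embedding instead of sitting next to it.

```python
import torch
from torch import nn

class ResidualAdapter(nn.Module):
    """Hypothetical new block: x + up(relu(down(x))), with `up` zero-initialized."""
    def __init__(self, dim, hidden):
        super().__init__()
        self.down = nn.Linear(dim, hidden)
        self.up = nn.Linear(hidden, dim)
        nn.init.zeros_(self.up.weight)
        nn.init.zeros_(self.up.bias)

    def forward(self, x):
        return x + self.up(torch.relu(self.down(x)))

torch.manual_seed(0)
pretrained = nn.Linear(8, 3)     # stands in for a pretrained model
adapter = ResidualAdapter(8, 4)  # newly inserted layer
x = torch.rand(5, 8)

same = torch.allclose(pretrained(adapter(x)), pretrained(x))
print(same)  # True: at initialization the new block is an identity map
```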

Here is a finetuning tutorial: TorchVision Object Detection Finetuning Tutorial — PyTorch Tutorials 2.2.0+cu121 documentation

You can see that the initial loss starts out high but quickly drops.

Sure, but this goes against what “Vision Transformer for Contrastive Clustering” suggested about improving accuracy. I did not expect such a huge drop.

Your original post stated:

  1. Pretrained ViT;
  2. Newly initialized ConvStem;
  3. After some epochs accuracy drops from 45% to 2%.

After looking at your code, we determined you were retraining the entire model and not just the ConvStem.

A significant drop in accuracy is to be expected in that case. This is because you’re starting with a high learning rate and new optimizer which are also being applied to the ViT pretrained weights. It’s the equivalent of taking a boulder and dropping it on an ice sculpture.

For instance, if you had model weights which you planned to continue training on, ideally you saved the optimizer too, so you have both the decayed learning rate and the parameter-wise optimizer values and aren’t starting from scratch. If you didn’t save the optimizer, you could start with a very low learning rate and incrementally increase it once the parameter-wise optimizer values adjust to your data and weights. But starting with a new optimizer and a high learning rate is very disruptive to the weights initially, which is why I suggest freezing the pretrained weights and just training the ConvStem.
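
The "start low and increase" idea can be sketched with a linear warmup via torch.optim.lr_scheduler.LambdaLR. The model, base_lr, and warmup_steps here are illustrative assumptions, not values tuned for CLIP.

```python
import torch
from torch import nn

model = nn.Linear(4, 2)          # hypothetical stand-in model
base_lr, warmup_steps = 1e-4, 5  # illustrative values only
optimizer = torch.optim.Adam(model.parameters(), lr=base_lr)

# the scale factor ramps linearly from 1/warmup_steps up to 1.0, then stays there
scheduler = torch.optim.lr_scheduler.LambdaLR(
    optimizer, lr_lambda=lambda step: min(1.0, (step + 1) / warmup_steps))

lrs = []
for step in range(8):
    lrs.append(optimizer.param_groups[0]["lr"])  # learning rate used for this step
    loss = model(torch.rand(3, 4)).pow(2).mean()
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    scheduler.step()
print(lrs)
```

The first few updates use a fraction of base_lr, which gives Adam's running statistics time to settle before the full learning rate is applied.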

