Why is Conv2d used instead of regular Patchify in the PyTorch ViT?

Hi there.
In the PyTorch implementation of ViT, Conv2d is used instead of regular patchify. In other words, the authors of An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale proposed a framework that receives an image as a sequence of patches and processes it with a self-attention mechanism, but the PyTorch version uses Conv2d instead.

VisionTransformer Class:

def _process_input(self, x: torch.Tensor) -> torch.Tensor:
        n, c, h, w = x.shape
        p = self.patch_size
        torch._assert(h == self.image_size, f"Wrong image height! Expected {self.image_size} but got {h}!")
        torch._assert(w == self.image_size, f"Wrong image width! Expected {self.image_size} but got {w}!")
        n_h = h // p
        n_w = w // p

        # (n, c, h, w) -> (n, hidden_dim, n_h, n_w)
        x = self.conv_proj(x)
        # (n, hidden_dim, n_h, n_w) -> (n, hidden_dim, (n_h * n_w))
        x = x.reshape(n, self.hidden_dim, n_h * n_w)

        # (n, hidden_dim, (n_h * n_w)) -> (n, (n_h * n_w), hidden_dim)
        # The self attention layer expects inputs in the format (N, S, E)
        # where S is the source sequence length, N is the batch size, E is the
        # embedding dimension
        x = x.permute(0, 2, 1)

        return x

def forward(self, x: torch.Tensor):
        # Reshape and permute the input tensor
        x = self._process_input(x)
        n = x.shape[0]

        # Expand the class token to the full batch
        batch_class_token = self.class_token.expand(n, -1, -1)
        x = torch.cat([batch_class_token, x], dim=1)

        x = self.encoder(x)

        # Classifier "token" as used by standard language architectures
        x = x[:, 0]

        x = self.heads(x)

        return x

Can anyone explain why Conv2d is used? Why not regular patchify?

if conv_stem_configs is not None:
    # As per https://arxiv.org/abs/2106.14881
    seq_proj = nn.Sequential()
    prev_channels = 3
    for i, conv_stem_layer_config in enumerate(conv_stem_configs):
        # (excerpt: each config entry adds a conv + norm + activation block to seq_proj here)
        prev_channels = conv_stem_layer_config.out_channels
    seq_proj.add_module(
        "conv_last", nn.Conv2d(in_channels=prev_channels, out_channels=hidden_dim, kernel_size=1)
    )
    self.conv_proj: nn.Module = seq_proj
else:
    # Default "patchify" stem: one convolution with kernel_size == stride == patch_size
    self.conv_proj = nn.Conv2d(
        in_channels=3, out_channels=hidden_dim, kernel_size=patch_size, stride=patch_size
    )

It says in the comments to see this paper: [2106.14881] Early Convolutions Help Transformers See Better

I’m sure they did it because the data shows it to be an improvement over not doing so.

It seems the original Google ViT model had a “patchify stem”, which is a single conv layer with a large 16x16 kernel and a stride of 16.

Specifying conv_stem_configs in the PyTorch ViT replaces that stem with a custom stack of convolutions. The paper above reports that a stack of 3x3 convolutions followed by a 1x1 convolution performs better.


I thought it might be helpful to show how to build a custom convolutional stem, given that the docs do not contain an example. This roughly follows what they did in the paper (a stack of 3x3 convolutions with stride 2):

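from torchvision.models.vision_transformer import ConvStemConfig, vit_b_16

# Four 3x3 stride-2 convolutions downsample by 2**4 = 16, matching the patch size of vit_b_16
n = 4
stemconfig = [ConvStemConfig(out_channels=64, kernel_size=3, stride=2) for _ in range(n)]

# weights=None builds an untrained model (replaces the deprecated pretrained=False in torchvision >= 0.13)
model = vit_b_16(weights=None, conv_stem_configs=stemconfig)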
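
As a sanity check on why four stride-2 layers are used: _process_input reshapes the stem output to n_h * n_w tokens with n_h = h // patch_size, so the stem has to downsample by the patch size overall. Here is a minimal sketch of that arithmetic in plain Python. The helper conv_out_size is hypothetical (not part of torchvision), and padding=1 for the 3x3 layers is my assumption about the default padding of the conv blocks.

```python
def conv_out_size(size: int, kernel_size: int, stride: int, padding: int) -> int:
    """Spatial output size of a convolution (standard floor formula, dilation 1)."""
    return (size + 2 * padding - kernel_size) // stride + 1


image_size = 224

# Patchify stem: one 16x16 convolution with stride 16 and no padding
patchify_grid = conv_out_size(image_size, kernel_size=16, stride=16, padding=0)

# Conv stem: four 3x3 stride-2 convolutions (padding=1 is an assumption)
conv_grid = image_size
for _ in range(4):
    conv_grid = conv_out_size(conv_grid, kernel_size=3, stride=2, padding=1)

# The final 1x1 convolution only changes the channel count, not the grid size
print(patchify_grid, conv_grid)  # 14 14
```

Both stems end up with a 14x14 grid, i.e. 196 tokens for a 224x224 image, so the rest of the model does not need to change.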

Thank you so much for this code example!
Do you have any idea why the researchers did not mention using convolution instead of regular patchify in the original ViT paper? I mean, according to the figure in the paper and its description, regular patchify is used…
Maybe what they mean by linear embedding is a convolution operation.

If you flatten each 16x16 patch and apply a linear layer to turn each one into a vector of size hidden_dim, you’re effectively doing the same thing as applying a Conv2d with a kernel size and stride of 16 and hidden_dim output channels. But a convolution is likely the preferred operation, if I recall correctly, because it’s faster on GPUs.

However, I cannot speak to the actual reasoning of the developers and the above is my own speculation regarding their reasoning.
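
The equivalence itself can be checked numerically. Below is a minimal sketch that copies the weights of a Conv2d into a Linear layer and compares the two outputs. The tensor sizes are small illustrative values I picked (not the ViT-B/16 defaults), and the reshape-based patchify is my own illustration, not code from torchvision.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Illustrative sizes (assumptions, not the ViT-B/16 defaults)
batch, channels, image_size, patch_size, hidden_dim = 2, 3, 32, 16, 8

x = torch.randn(batch, channels, image_size, image_size)

# Path 1: Conv2d with kernel_size == stride == patch_size
conv = nn.Conv2d(channels, hidden_dim, kernel_size=patch_size, stride=patch_size)
conv_tokens = conv(x).flatten(2).transpose(1, 2)  # (batch, num_patches, hidden_dim)

# Path 2: explicit patchify, then a linear layer that shares the conv weights
n_h = n_w = image_size // patch_size
patches = x.reshape(batch, channels, n_h, patch_size, n_w, patch_size)
patches = patches.permute(0, 2, 4, 1, 3, 5)  # (batch, n_h, n_w, channels, p, p)
patches = patches.reshape(batch, n_h * n_w, channels * patch_size * patch_size)

linear = nn.Linear(channels * patch_size * patch_size, hidden_dim)
with torch.no_grad():
    # conv.weight is (hidden_dim, channels, p, p); flatten it to (hidden_dim, channels * p * p)
    linear.weight.copy_(conv.weight.reshape(hidden_dim, -1))
    linear.bias.copy_(conv.bias)
linear_tokens = linear(patches)

# The two paths should agree up to floating-point error
same = torch.allclose(conv_tokens, linear_tokens, atol=1e-4)
print(same)
```

This only shows that the two operations compute the same thing. It says nothing about which one is faster.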
