Reconstructed ViT not running forward() as expected

I am trying to remove the last fc layer of ViT and create my own fc layer.

I used this link to import vit_b_16

import torch.nn as nn
from torchvision.models import vit_b_16


class ViT_df(nn.Module):
    def __init__(self, num_classes=2, model_name='vit16', pretrained=True):
        super().__init__()
        if model_name == 'vit16':
            model = vit_b_16(pretrained=pretrained)
        else:
            raise ValueError('wrong model name')

        layers = list(model.children())[:-1]
        fc_size = model.hidden_dim
        self.parent = nn.Sequential(*layers)
        self.dropout = nn.Dropout(p=0.5)  # used in forward() below; the p value here is arbitrary
        self.fc = nn.Linear(in_features=fc_size, out_features=num_classes)

    def forward(self, image):
        x = self.parent(image)
        x = self.dropout(x)
        x = x.view(x.size(0), -1)
        x = self.fc(x)

        return x

However, it raises this error:

AssertionError: Expected (batch_size, seq_length, hidden_dim) got torch.Size([50, 768, 14, 14])

Here batch_size=50, seq_length = 14*14, and hidden_dim=768, so the error is clearly caused by def _process_input(self, x: torch.Tensor) -> torch.Tensor: never being called. (in this link)
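
For reference, the missing step looks roughly like this (a paraphrased sketch of the linked torchvision _process_input; conv_proj and hidden_dim are attributes of that implementation):

# Rough sketch of what VisionTransformer._process_input does (paraphrased, not my code)
import torch
from torchvision.models import vit_b_16

model = vit_b_16()
imgs = torch.randn(50, 3, 224, 224)

x = model.conv_proj(imgs)                        # -> (50, 768, 14, 14)
x = x.reshape(x.shape[0], model.hidden_dim, -1)  # -> (50, 768, 196)  flatten the 14x14 patch grid
x = x.permute(0, 2, 1)                           # -> (50, 196, 768)  (batch_size, seq_length, hidden_dim)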

I debugged my code. My reconstructed ViT model runs forward() of class Encoder(nn.Module): before ever reaching forward() of class VisionTransformer(nn.Module):. (in this link) This is very strange to me, and I need your help.

Thank you!!!

What I assume is happening is that you are getting rid of the whole self.heads layer that comes after the Encoder, not only the last nn.Linear that is inside self.heads. You can print the list of layers that you are taking from the parent to be sure.
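
A minimal way to check (a sketch, assuming torchvision's vit_b_16):

from torchvision.models import vit_b_16

model = vit_b_16()
for i, child in enumerate(model.children()):
    print(i, type(child).__name__)
# For torchvision's VisionTransformer this should print roughly:
# 0 Conv2d      (the patch projection, conv_proj)
# 1 Encoder     (the transformer encoder)
# 2 Sequential  (self.heads, which [:-1] removes entirely)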

If this is the case, then you need to take only the class token before passing it through your custom nn.Linear.

# Only the class token
x = x[:, 0]

This should work fine for fine-tuning.

class ViT_df(nn.Module):
    def __init__(self, num_classes=2, model_name='vit16', pretrained=True):
        super().__init__()
        if model_name == 'vit16':
            model = vit_b_16(pretrained=pretrained)
        else:
            raise ValueError('wrong model name')

        layers = list(model.children())[:-1]
        fc_size = model.hidden_dim
        self.parent = nn.Sequential(*layers)
        self.dropout = nn.Dropout(p=0.5)  # define the dropout used in forward(); p is arbitrary here
        self.fc = nn.Linear(in_features=fc_size, out_features=num_classes)

    def forward(self, image):
        x = self.parent(image)
        # Select ONLY the class token (position 0 in the sequence dimension)
        x = x[:, 0]

        x = self.dropout(x)

        # I'm not sure you still need this
        # x = x.view(x.size(0), -1)

        x = self.fc(x)

        return x

If you want to use the full MLP head after the encoder, then you can either copy the full self.heads, or make sure that you only get rid of the last layer inside self.heads rather than the whole thing.
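
A minimal sketch of the second option, assuming torchvision's vit_b_16, where the final classifier lives at model.heads.head:

import torch.nn as nn
from torchvision.models import vit_b_16

model = vit_b_16(pretrained=True)
# Replace only the last Linear inside self.heads; everything else stays intact
model.heads.head = nn.Linear(model.hidden_dim, 2)  # 2 = num_classes, just as an example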

If you are getting rid of only the last layer and the problem is still there, please let me know.

Hope this helps.

P.S.: I am only assuming this because, in the link you posted, self.heads is defined as an nn.Sequential with other layers as children. This means that the actual layer you want to replace is a child of the child you are removing.

Thank you for your reply, Matias. I tried your suggestion, but I am afraid it does not cover my case.

  1. The error happens in x = self.parent(image), so adding x = x[:, 0] after it does not help.
  2. There is only one nn.Linear layer in heads when representation_size is not set, which is my case (see the quick check after this snippet):
        if representation_size is None:
            heads_layers["head"] = nn.Linear(hidden_dim, num_classes)
        else:
            heads_layers["pre_logits"] = nn.Linear(hidden_dim, representation_size)
            heads_layers["act"] = nn.Tanh()
            heads_layers["head"] = nn.Linear(representation_size, num_classes)

Oh I see. As you mentioned, the tensor entering the encoder should have shape B x Seq_len x Hidden_dim = 50 x (14*14 + class token) x 768.

Could you add a print like this to see in which exact layer the problem occurs?

def forward(self, x):
    for layer in self.parent:
        print(layer, x.shape)
        x = layer(x)

    x = self.dropout(x)
    x = x.view(x.size(0), -1)
    x = self.fc(x)

    return x

Yes, here it is.

Conv2d(3, 768, kernel_size=(16, 16), stride=(16, 16)) torch.Size([50, 3, 224, 224])
Encoder(
  (dropout): Dropout(p=0.0, inplace=False)
  (layers): Sequential(
    (encoder_layer_0): EncoderBlock(
      (ln_1): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
      (self_attention): MultiheadAttention(
        (out_proj): NonDynamicallyQuantizableLinear(in_features=768, out_features=768, bias=True)
      )
      (dropout): Dropout(p=0.0, inplace=False)
      (ln_2): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
      (mlp): MLPBlock(
        (linear_1): Linear(in_features=768, out_features=3072, bias=True)
        (act): GELU()
        (dropout_1): Dropout(p=0.0, inplace=False)
        (linear_2): Linear(in_features=3072, out_features=768, bias=True)
        (dropout_2): Dropout(p=0.0, inplace=False)
      )
    )
    ... (encoder_layer_1 through encoder_layer_11 are identical to encoder_layer_0 and are omitted here for brevity) ...
  )
  (ln): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
) torch.Size([50, 768, 14, 14])

You mentioned in your first post that you are not running _process_input(). I thought you meant it was inside your Encoder or something like that, but you do need to run it in order to get the expected shape.

Here is where the patches are flattened.

Also, I do not see the class token being added anywhere.

Is there a reason you left the _process_input method out?

Edit:

  1. In the implementation from your link, the positional embedding is added inside the encoder.
  2. You can access the class_token from the original model, since it is defined as self.class_token = nn.Parameter(...) in the VisionTransformer class.
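
Something like this (a sketch, using the attribute names from torchvision's VisionTransformer and Encoder):

from torchvision.models import vit_b_16

model = vit_b_16(pretrained=True)
cls_token = model.class_token            # nn.Parameter of shape (1, 1, 768)
pos_embed = model.encoder.pos_embedding  # positional embedding, added inside Encoder.forward()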

The same issue is being discussed here.

No functional API calls are made. In the comments I explain a bit of how this is done.

Please let me know if you still need help.

import torch
import torch.nn as nn
import torchvision

img = torch.randn(1, 3, 224, 224)


model = torchvision.models.vit_b_16()
feature_extractor = nn.Sequential(*list(model.children())[:-1])

# This is supposed to be the PREPROCESS step,
# but it is not done correctly here, since the reshaping and permutation are missing.
# Only the convolution is applied.
conv = feature_extractor[0]

# -> print(conv(img).shape)
# -> torch.Size([1, 768, 14, 14])
# This is not the desired output after preprocessing the image into
# flat patches. Also, in the pytorch implementation, the class token
# and positional embedding are added separately in the forward method.

# This is the whole encoder sequence
encoder = feature_extractor[1]

# The MLP head at the end is gone, since you only selected the children until -1
# mlp = feature_extractor[2]

# This is how the model preprocesses the image:
# convolution + flatten + permute into a sequence of patch embeddings.
x = model._process_input(img)

# -> print(x.shape)
# -> torch.Size([1, 196, 768])
# This is Batch x N_Patches x Hidden_dim, i.e. 1 x (14*14) x 768.
# The class token is not included yet; it is concatenated below,
# just like in VisionTransformer.forward() from the source code.
# The positional embedding is added inside the encoder, so that part is handled for us.

# The next lines are copied from the forward method of vit_b_16 in the
# pytorch source code, in order to prepend the class token and run the encoder.

n = x.shape[0]
# Expand the class token to the full batch
batch_class_token = model.class_token.expand(n, -1, -1)
x = torch.cat([batch_class_token, x], dim=1)
x = encoder(x)

# Classifier "token" as used by standard language architectures

x = x[:, 0]

# Here you can use your own nn.Linear to map to your number of classes
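
For completeness, continuing the snippet above, a minimal sketch of that last step (the layer name and class count below are just an example, not part of the original code):

num_classes = 2                                    # example value
my_head = nn.Linear(model.hidden_dim, num_classes)
logits = my_head(x)

# -> print(logits.shape)
# -> torch.Size([1, 2])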