Feature extraction in torchvision.models.vit_b_16

Hi

It’s easy enough to obtain output features from the CNNs in torchvision.models by doing this:

import torch
import torch.nn as nn
import torchvision.models as models

model = models.resnet18()
feature_extractor = nn.Sequential(*list(model.children())[:-1])
output_features = feature_extractor(torch.randn(1, 3, 224, 224))

However, this does not work when I try it with torchvision.models.vit_b_16. It results in the following error:

AssertionError: Expected (batch_size, seq_length, hidden_dim) got torch.Size([1, 768, 14, 14])

Any help would be greatly appreciated.

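Wrapping the submodules into an nn.Sequential container only works for models that use nn.Modules exclusively (no functional API calls) and that initialize them in the same order in which they are called.
In your use case you would lose all functional calls used in the original forward method (such as the reshape and permute of the patch embeddings and the concatenation of the class token), as seen in the VisionTransformer source.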


@ptrblck beat me to it :smile:.

Hope this helps.

import torch
import torch.nn as nn
import torchvision

img = torch.randn(1, 3, 224, 224)


model = torchvision.models.vit_b_16()
feature_extractor = nn.Sequential(*list(model.children())[:-1])

# This is supposed to be the PREPROCESS,
# but it is not done correctly, since the reshaping and permutation are missing:
# only the convolution is applied
conv = feature_extractor[0]

# -> print(conv(img).shape)
# -> torch.Size([1, 768, 14, 14])
# This is not the desired output after preprocessing the image into
# flat patches. Also, in the PyTorch implementation the class token
# is concatenated separately in the forward method.

# This is the whole encoder sequence
encoder = feature_extractor[1]

# The MLP head at the end is gone, since you only selected the children until -1
# mlp = feature_extractor[2]

# This is how the model preprocesses the image.
# The output has the desired (batch, sequence, embedding) layout
x = model._process_input(img)

# -> print(x.shape)
# -> torch.Size([1, 196, 768])
# This is Batch x N_Patches x hidden_dim
# Meaning   1   x   14*14   x    768
# (for vit_b_16, hidden_dim = 768 happens to equal C * H_patch * W_patch = 3 * 16 * 16)

# There are only 196 tokens in dim=1: the class token is still missing.
# The positional embedding is added inside the encoder, so that part should be fine.

# The next code is copied from the forward method of VisionTransformer
# (the model behind vit_b_16) in the torchvision source code

n = x.shape[0]
# Expand the class token to the full batch
batch_class_token = model.class_token.expand(n, -1, -1)
x = torch.cat([batch_class_token, x], dim=1)
x = encoder(x)

# Classifier "token" as used by standard language architectures

x = x[:, 0]

# Here you can use your own nn.Linear to map to your number of classes

@Matias_Vasquez @ptrblck
Thanks so much. So the “representation” for that image in the ViT model is the first vector extracted using x=x[:, 0] line?

Thanks a lot!

Yes, the class token in the first position (index 0) is where the classification is made. However, this vector still has the size of a single patch embedding (hidden_dim = 768 for vit_b_16).

This means that if each patch is 3*16*16 (CxHxW), then the shapes are

# 14*14 patches
# 3*16*16 values in each patch (C x H x W)

x.shape # -> 1x197x768 = 1 x (14*14+1) x (3*16*16)
x = x[:, 0]
x.shape # -> 1x768

So after the encoder comes the MLP, which reduces this to the number of classes you want. Some implementations use a couple of layers, but a single nn.Linear(768, num_classes) should be enough.
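The shape bookkeeping above, plus a single linear head, can be written out as a small sketch. The encoder output here is a random stand-in tensor, num_classes = 10 is a hypothetical value, and 768 is vit_b_16's hidden_dim.

```python
import torch
import torch.nn as nn

image_size, patch_size, channels = 224, 16, 3
num_patches = (image_size // patch_size) ** 2      # 14 * 14 = 196
seq_len = num_patches + 1                          # plus the class token = 197
patch_values = channels * patch_size * patch_size  # 3 * 16 * 16 = 768

num_classes = 10                    # hypothetical number of classes
head = nn.Linear(768, num_classes)  # 768 = vit_b_16 hidden_dim

encoder_out = torch.randn(1, seq_len, 768)  # random stand-in for the encoder output
cls_token = encoder_out[:, 0]               # [1, 768]
logits = head(cls_token)                    # [1, num_classes]
print(seq_len, patch_values, tuple(logits.shape))
```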

@Matias_Vasquez

Many thanks for your help. I was getting confused as to why one doesn’t take the mean of x when it has shape 197x768 to produce 1x768. But I understand now why we instead take the first element to produce 1x768.

Thanks!
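For comparison, here is a sketch of both pooling choices on a random stand-in for the encoder output. torchvision's vit_b_16 uses the class token (index 0). Mean pooling over the 196 patch tokens is shown only as an alternative; a pretrained vit_b_16 head was trained on the class token, not on the mean.

```python
import torch

encoder_out = torch.randn(4, 197, 768)  # random stand-in: batch of 4 encoder outputs

# What torchvision's ViT does: take the class token at position 0
cls_features = encoder_out[:, 0]                # [4, 768]

# Alternative (not what vit_b_16 does): average the 196 patch tokens
mean_features = encoder_out[:, 1:].mean(dim=1)  # [4, 768]

print(cls_features.shape, mean_features.shape)
```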


Hello, do you have a solution for this problem? Thank you!

@Matias_Vasquez explained the proper approach really well. :slight_smile: Could you explain where you are stuck in this explanation?


I am having a similar challenge. I am trying to use vit_b_16 as my backbone for faster rcnn.

import torch
import torchvision
import torchvision.models as models
from torchvision.models.detection import FasterRCNN
from torchvision.models.detection.backbone_utils import LastLevelMaxPool
from torchvision.models.feature_extraction import create_feature_extractor


class VITWithFPN(torch.nn.Module):
    def __init__(self):
        super(VITWithFPN, self).__init__()
        # Get a VIT backbone
        self.model = models.vit_b_16(pretrained=True, image_size=224)

        # Extract the encoder
        self.body = create_feature_extractor(
            self.model, return_nodes=['encoder'])

        inp = torch.randn(2, 3, 224, 224)

        with torch.no_grad():
            out = self.body(inp)
        in_channels_list = [o.shape[1] for o in out.values()]
        # Build FPN
        self.out_channels = self.model.hidden_dim
        self.fpn = torchvision.ops.FeaturePyramidNetwork(
            in_channels_list, out_channels=self.out_channels,
            extra_blocks=LastLevelMaxPool())

    def _process_input(self, x: torch.Tensor) -> torch.Tensor:
        n, c, h, w = x.shape
        print(n, c, h, w)
        print(x.shape)

        p = self.model.patch_size
        torch._assert(h == self.model.image_size, "Wrong image height!")
        torch._assert(w == self.model.image_size, "Wrong image width!")
        n_h = h // p
        n_w = w // p

        # (n, c, h, w) -> (n, hidden_dim, n_h, n_w)
        x = self.model.conv_proj(x)
        # (n, hidden_dim, n_h, n_w) -> (n, hidden_dim, (n_h * n_w))
        x = x.reshape(n, self.model.hidden_dim, n_h * n_w)

        # (n, hidden_dim, (n_h * n_w)) -> (n, (n_h * n_w), hidden_dim)
        # The self attention layer expects inputs in the format (N, S, E)
        # where S is the source sequence length, N is the batch size, E is the
        # embedding dimension
        x = x.permute(0, 2, 1)

        return x

    def forward(self, x):
        print("Image is -->", x.size())
        x = self._process_input(x)
        n = x.shape[0]
        # Expand the class token to the full batch
        batch_class_token = self.model.class_token.expand(n, -1, -1)
        x = torch.cat([batch_class_token, x], dim=1)
        x = self.body(x)
        x = self.fpn(x)
        return x


# Build model
model = FasterRCNN(VITWithFPN(), num_classes=2)

I resized the images in my dataloader to 224*224. However, before the images reach the forward method, the image tensors get transformed and have shape torch.Size([8, 3, 800, 800])
instead of torch.Size([3, 224, 224]) per image. I would appreciate your help.
Thanks in advance @ptrblck @Matias_Vasquez

Hi,

sorry for the late reply. Could you show where the shape changes, or where this error occurs?

The shape changes before the first line of code in the forward function.

Ok,

since you are printing the shape of x in the first line of your forward function, the ViT is NOT changing the size of your image; something you are doing before it is.

I assume that the change in size is happening in your FasterRCNN before you feed it to the ViT.

If you post the code showing how you create this FasterRCNN, then maybe I can help you a little more.

import os
import os.path as osp
import configparser
import csv
import pickle
import random
from typing import overload

import numpy as np
import scipy
import matplotlib.pyplot as plt
from PIL import Image
import torch
import torch.nn as nn
import torchvision
import torchvision.models as models
import torchvision.transforms as T
from torchvision.models.detection import FasterRCNN
from torchvision.models.detection.backbone_utils import LastLevelMaxPool
from torchvision.models.detection.faster_rcnn import FastRCNNPredictor
from torchvision.models.detection.rpn import AnchorGenerator
from torchvision.models.feature_extraction import get_graph_node_names, create_feature_extractor
from torchvision.models.vision_transformer import vit_b_16

import utils
from dataloader import MOT17ObjDetect
from engine import train_one_epoch, evaluate

OUTPUT_DIR = "outputs"
LOG_RESULTS = "inference"
PREDICTED_IMAGES = "images_with_boxes"


def plot(img, boxes):
    x = random.randint(100000, 100000000)
    fig, ax = plt.subplots(1, dpi=96)

    # (C, H, W) float tensor in [0, 1] -> (H, W, C) uint8 array
    img = img.mul(255).permute(1, 2, 0).cpu().byte().numpy()
    height, width, _ = img.shape

    ax.imshow(img, cmap='gray')
    fig.set_size_inches(width / 80, height / 80)

    for box in boxes:
        rect = plt.Rectangle(
            (box[0], box[1]),
            box[2] - box[0],
            box[3] - box[1],
            fill=False,
            linewidth=1.0)
        ax.add_patch(rect)

    plt.axis('off')
    plt.savefig("./images_with_boxes/" + str(x) + ".png")
    plt.close(fig)  # free the figure, since plot() is called many times

# class VITWithFPN: identical to the class in my previous post (omitted here)

# Build model
model = FasterRCNN(VITWithFPN(), num_classes=2)


def get_transform(train):
    transforms = []
    # converts the image, a PIL image, into a PyTorch Tensor
    transforms.append(T.ToTensor())
    if train:
        # during training, randomly flip the training images
        # and ground-truth for data augmentation
        transforms.append(T.RandomHorizontalFlip(0.5))
    transforms.append(T.Resize((224, 224)))
    return T.Compose(transforms)


# use our dataset and defined transformations
dataset = MOT17ObjDetect('./data/MOT17Det/train', get_transform(train=True))
dataset_no_random = MOT17ObjDetect('./data/MOT17Det/train', get_transform(train=False))
dataset_test = MOT17ObjDetect('./data/MOT17Det/test', get_transform(train=False))

# split the dataset in train and test set
torch.manual_seed(1)

# define training and validation data loaders
data_loader = torch.utils.data.DataLoader(
    dataset, batch_size=8, shuffle=True, num_workers=4,
    collate_fn=utils.collate_fn)
data_loader_no_random = torch.utils.data.DataLoader(
    dataset_no_random, batch_size=8, shuffle=False, num_workers=4,
    collate_fn=utils.collate_fn)

data_loader_test = torch.utils.data.DataLoader(
    dataset_test, batch_size=2, shuffle=False, num_workers=4,
    collate_fn=utils.collate_fn)

device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

# get the model using our helper function
# (get_detection_model is not defined in this file, and this line replaces
# the FasterRCNN(VITWithFPN(), ...) model built above)
model = get_detection_model(dataset.num_classes)

# move model to the right device
model.to(device)

# construct an optimizer
params = [p for p in model.parameters() if p.requires_grad]
optimizer = torch.optim.SGD(params, lr=0.00125,
                            momentum=0.9, weight_decay=0.0005)

# and a learning rate scheduler which decreases the learning rate by
# 10x every 10 epochs (step_size=10)
lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer,
                                               step_size=10,
                                               gamma=0.1)
os.makedirs(OUTPUT_DIR, exist_ok=True)
os.makedirs(LOG_RESULTS, exist_ok=True)


def evaluate_and_write_result_files(model, data_loader):
    model.eval()
    results = {}
    for imgs, targets in data_loader:
        imgs = [img.to(device) for img in imgs]

        with torch.no_grad():
            preds = model(imgs)

        for pred, target in zip(preds, targets):
            results[target['image_id'].item()] = {'boxes': pred['boxes'].cpu(),
                                                  'scores': pred['scores'].cpu()}

    data_loader.dataset.print_eval(results)
    data_loader.dataset.write_results_files(results, OUTPUT_DIR + "/")


num_epochs = 27

for epoch in range(1, num_epochs + 1):
    train_one_epoch(model, optimizer, data_loader, device, epoch, print_freq=300)
    # update the learning rate
    # lr_scheduler.step()
    # evaluate on the test dataset
    if epoch % 3 == 0:
        evaluate_and_write_result_files(model, data_loader_no_random)
        torch.save(model.state_dict(), f"{OUTPUT_DIR}/model_epoch_{epoch}.model")

# pick images from the training set for visualization
os.makedirs(PREDICTED_IMAGES, exist_ok=True)
dataset = MOT17ObjDetect('./data/MOT17Det/train', get_transform(train=False))
data_loader = torch.utils.data.DataLoader(
    dataset, batch_size=20, shuffle=True, num_workers=4,
    collate_fn=utils.collate_fn)

print("---------------Done training------------------")
# put the model in evaluation mode
model.eval()
for imgs, target in data_loader:
    with torch.no_grad():
        # the last batch can hold fewer than 20 images
        for k in range(len(imgs)):
            prediction = model([imgs[k].to(device)])[0]
            plot(imgs[k], prediction['boxes'].cpu())

@Matias_Vasquez, here is the complete code file.

Since you are using model = FasterRCNN(VITWithFPN(), num_classes=2) to declare your FasterRCNN, you can use one of its arguments to make sure that the image has the right size before being fed to your backbone.

The FasterRCNN documentation lists all of the arguments that you can use for this object. The default min_size is 800, so FasterRCNN's internal transform rescales your 224x224 images up to 800x800. If you set max_size=224, this should be solved.

model = FasterRCNN(VITWithFPN(), num_classes=2, max_size=224)
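The effect of max_size can be checked with a little arithmetic. The sketch below is an approximate re-implementation of the scale rule used by FasterRCNN's internal GeneralizedRCNNTransform: scale the shorter side up to min_size, unless that pushes the longer side past max_size. The helper name rcnn_resize_scale is hypothetical.

```python
def rcnn_resize_scale(height, width, min_size=800, max_size=1333):
    """Sketch of the resize scale factor (defaults match FasterRCNN's)."""
    short_side = min(height, width)
    long_side = max(height, width)
    # scale the short side to min_size, capped so the long side stays <= max_size
    return min(min_size / short_side, max_size / long_side)


default_scale = rcnn_resize_scale(224, 224)               # 800 / 224
capped_scale = rcnn_resize_scale(224, 224, max_size=224)  # 1.0
print(round(224 * default_scale), round(224 * capped_scale))  # 800 224
```

With square 224x224 inputs, setting min_size=224 (and keeping the default max_size) also gives a scale of 1.0.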

However, I think there are still some issues with the Feature Pyramid.
Also, you are not selecting the class token as explained in the previous posts, but if that is what you want, you can try it out.

Hope this helps :smile:

Yes, there is still an issue. Could you please point out the error in my feature pyramid network?

Sorry but I do not quite understand HOW you want to use the FPN with the ViT.

If you declare your body like this, self.body = create_feature_extractor(self.model, return_nodes=['encoder']), and you print it, you will see that you do not have only the encoder. You have the entire ViT up to the end of the encoder: the preprocessing (patch embedding and class token), the encoder layers and the final LayerNorm.

This means that when you call self.body(inp), the input passes through the entire ViT and returns when it finishes with the encoder (before the head). So you do not have to do the preprocessing yourself in forward.

As mentioned, the output is the output of the encoder, which has shape torch.Size([1, 197, 768]) for a single image. As mentioned in previous posts, this is where you normally take the class token and pass it through an MLP to get the classification.

This is only ONE output tensor, not the outputs of every encoder layer. So for your FPN you are only using this output to define your in_channels_list, and o.shape[1] of this tensor is the sequence length (197), not a channel dimension.

So it really depends on how you want to use your FPN, and whether it really is the way to go instead of a normal MLP. I do not know.

Hello @Matias_Vasquez, thanks so much. I am really new to using ViT as a backbone. How would I adapt the MLP to use it in my FPN? Thanks in advance.

There are two other ways to do the same:
1°: Feature extraction for model inspection — Torchvision 0.12 documentation (pytorch.org)

import torch
import torchvision.models as models
from torchvision.models.feature_extraction import create_feature_extractor, get_graph_node_names

model = models.vit_b_16()
# print(get_graph_node_names(model))  # lists the node names you can request
x = torch.randn(1, 3, 224, 224)
# return_nodes maps a graph node name to the key used in the output dict
output_feature = create_feature_extractor(model, return_nodes={"encoder.ln": "getitem_5"})
output_feature(x)  # -> {"getitem_5": tensor of shape [1, 197, 768]}

2°: torch.nn.modules.module.register_module_forward_hook — PyTorch 1.11.0 documentation

import torch
import torchvision.models as models

model = models.vit_b_16()

def print_middle_layer(module, input, output):
    print("Print Output:", output)

# model.encoder.ln is the final LayerNorm of the encoder, so the hook
# receives the encoder output of shape [1, 197, 768]
model.encoder.ln.register_forward_hook(print_middle_layer)
x = torch.randn(1, 3, 224, 224)
model(x)

Thanks. Basically, what I am trying to do is use vit_b_16() as my feature extractor for Faster R-CNN instead of a ResNet architecture (ResNet50, ResNet101 and ResNet152). I just ran your code @diegoaichele and I am getting some errors.
I want to use vit_b_16() as my backbone for Faster R-CNN. Maybe you could help me do that?