AssertionError: Expected (batch_size, seq_length, hidden_dim) got torch.Size([1, 768, 24, 31])

Hi, sorry for the late reply.

I'd recommend a different approach instead of fixing this one. If you want, we can look at how to fix your version later.

The problem

So, the problem is that when you use nn.Sequential you keep the nn.Modules used inside the architecture, but everything done through the functional API in forward is lost. As you saw in the source code that I linked, there are some steps inside the forward method of ViT that are then skipped (flattening + cls token).

Possible solutions

As you have already tried, there are ways of modifying the model to fit your needs, but they can be problematic. Here are some possible solutions:

  • Rewrite the code (more or less what you are doing)
  • Add forward hooks (a little complicated but it works)
  • Use torch.fx feature extractor (I think this is my favorite)

Feature Extraction in TorchVision using Torch FX

To understand how it works, why, and what it's good for, I really recommend this post. I will only explain the bare minimum to get it working.

The network

Here we load the network the same way you did before:

import torchvision
network = getattr(torchvision.models, "vit_b_16")(pretrained=True)

The Feature Extractor

This is where the magic happens

from torchvision.models.feature_extraction import create_feature_extractor
feature_extractor = create_feature_extractor(network, return_nodes=['getitem_5'])

But now you are asking: "Matias, why did you use 'getitem_5'?" And that is a great question.

Well, you need to specify the return_nodes. This is a list of the places in your model where you want to tap the output for a given input, and it can contain more than one. For this you need to know the exact node_name of each place.

Graph Node Names

In order to know how the node_names are defined, you can use the following code

print(torchvision.models.feature_extraction.get_graph_node_names(network))

This will give you the node names (one list for train mode and one for eval mode), each a long list like this one (these are just the last nodes):
[image: screenshot of the last node names]

So, as you can see, the last node before head is called getitem_5. Here is where we want to get the information.

Using the feature extractor

Putting it all together, it would look something like this:

import torch
import torchvision
from torchvision.models.feature_extraction import create_feature_extractor

network = getattr(torchvision.models,"vit_b_16")(pretrained=True)
feature_extractor = create_feature_extractor(network, return_nodes=['getitem_5'])

img = torch.rand(1, 3, 224, 224)
print(feature_extractor(img)['getitem_5'])

If you try this, you should not get any errors. The output shape should be torch.Size([1, 768]). You can now feed this to an MLP or whatever else you want to do with it.
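
A minimal sketch of the MLP step, assuming a 10-class problem. The layer sizes and num_classes are placeholders, and a random tensor stands in for feature_extractor(img)['getitem_5'] so the snippet runs on its own.

```python
import torch
import torch.nn as nn

num_classes = 10  # hypothetical; use your own number of classes

# Small MLP head on top of the 768-dimensional ViT features.
mlp_head = nn.Sequential(
    nn.Linear(768, 256),
    nn.ReLU(),
    nn.Linear(256, num_classes),
)

# Stand-in for the extracted features, shape (batch_size, 768).
features = torch.rand(1, 768)
logits = mlp_head(features)
print(logits.shape)
```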

Hope this helps :smile: As I said, if you still want to pursue your approach, I can look at what went wrong, but I think this approach is easier and less error-prone.
