Hi, sorry for the late reply.
I would recommend a different approach instead of fixing this one. Later, if you want, we can look at how to fix your original code.
The problem
So, the problem is that when you use nn.Sequential you get the nn.Modules used inside the architecture, but everything done through the functional API is gone. As you saw in the source code that I linked, there is some stuff going on inside the forward method of ViT that does not get done (flattening + cls token).
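To make this concrete, here is a minimal sketch with a toy model. TinyNet is a made-up stand-in, not ViT, but it has the same kind of problem: a functional step inside forward that nn.Sequential never sees.

```python
import torch
from torch import nn


class TinyNet(nn.Module):
    """Toy model (hypothetical) with a functional step inside forward()."""

    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(3, 4, kernel_size=3, padding=1)
        self.fc = nn.Linear(4 * 8 * 8, 10)

    def forward(self, x):
        x = self.conv(x)
        x = torch.flatten(x, 1)  # functional call, not a registered submodule
        return self.fc(x)


model = TinyNet()
x = torch.rand(2, 3, 8, 8)
full_shape = model(x).shape  # the real forward() works

# Rebuilding from children() keeps conv and fc but drops the flatten step.
rebuilt = nn.Sequential(*model.children())
try:
    rebuilt(x)
    rebuilt_failed = False
except RuntimeError as err:
    rebuilt_failed = True
    print("Sequential rebuild failed:", err)
```

The rebuilt model fails with a shape mismatch because fc receives the unflattened conv output. ViT loses its reshaping and class-token handling in the same way.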
Possible solutions
As you have already tried, there are some ways of modifying the model to fit your needs; however, this is sometimes problematic. Here are some possible solutions:
- Rewrite the code (more or less what you are doing)
- Add forward hooks (a little complicated but it works)
- Use the torch.fx feature extractor (I think this is my favorite)
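For reference, the forward-hook option could look roughly like this. It is a minimal sketch on a small made-up model, not on ViT; the captured dict and the save_output function are names I chose for illustration.

```python
import torch
from torch import nn

# Small stand-in model (an assumption for illustration, not the ViT).
model = nn.Sequential(nn.Linear(8, 16), nn.ReLU(), nn.Linear(16, 4))

captured = {}


def save_output(module, inputs, output):
    # Called after the hooked module's forward(); store a detached copy.
    captured["hidden"] = output.detach()


# Hook the ReLU so we capture the hidden activation.
handle = model[1].register_forward_hook(save_output)
out = model(torch.rand(2, 8))
handle.remove()  # remove the hook once it is no longer needed

print(captured["hidden"].shape)
```

It works, but you have to manage the hook handles and the storage yourself, which is why I find it a little more complicated.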
Feature Extraction in TorchVision using Torch FX
To understand how and why it works and what it is good for, I really recommend this post. I will only explain the bare minimum for it to work.
The network
Here we do the same thing that you did before to load the network
import torchvision
network = getattr(torchvision.models, "vit_b_16")(pretrained=True)
The Feature Extractor
This is where the magic happens
from torchvision.models.feature_extraction import create_feature_extractor
feature_extractor = create_feature_extractor(network, return_nodes=['getitem_5'])
But now you are asking: "Matias, why did you use 'getitem_5'?" And that is a great question.
Well, you need to specify the return_nodes. This means you can specify a list with many places where you want to interrupt your model and get the output for a given input. For this you need to know the exact node_name.
Graph Node Names
In order to know how the node_names are defined, you can use the following code:
print(torchvision.models.feature_extraction.get_graph_node_names(network))
This will give you a long list of every node name like this one (these are just the last nodes)
So, as you can see, the last node before head is called getitem_5. This is where we want to get the information.
Using the feature extractor
Putting it all together, it would look something like this:
import torch
import torchvision
from torchvision.models.feature_extraction import create_feature_extractor
# Load the pretrained ViT as before
network = getattr(torchvision.models, "vit_b_16")(pretrained=True)
# Stop at the node right before the classification head
feature_extractor = create_feature_extractor(network, return_nodes=['getitem_5'])
img = torch.rand(1, 3, 224, 224)
print(feature_extractor(img)['getitem_5'])
If you try this, you should not get any error. The output shape should be torch.Size([1, 768]). You can now feed this to an MLP or whatever you want to do with it.
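Feeding the features to an MLP could look something like the sketch below. The layer sizes and num_classes are made-up values, and the random tensor is only a stand-in for the real feature_extractor output so the snippet runs on its own.

```python
import torch
from torch import nn

num_classes = 10  # hypothetical; use the number of classes in your task

# Small MLP head on top of the 768-dimensional ViT features.
mlp_head = nn.Sequential(
    nn.Linear(768, 256),
    nn.ReLU(),
    nn.Linear(256, num_classes),
)

# Stand-in for feature_extractor(img)['getitem_5'] with a batch of 4 images.
features = torch.rand(4, 768)
logits = mlp_head(features)
print(logits.shape)
```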
Hope this helps. As I said, if you want to look at your approach anyway, I can look at what went wrong, but I think this approach is easier and less prone to error.