Torchvision ViT: changing the number of classes for fine-tuning

Is this the right way to change num_classes in ViT?

self.model = models.vit_b_16(pretrained=True)  # Load with 1000 classes first
        
if hasattr(self.model.heads, "head") and isinstance(self.model.heads.head, nn.Linear):
    self.model.heads.head = nn.Linear(self.model.heads.head.in_features, num_classes)

pretrained=True is deprecated and should eventually be replaced by the appropriate weights= argument (for example, weights=models.ViT_B_16_Weights.IMAGENET1K_V1), but other than that, it looks reasonable. I would change the if into an assert because, quite likely, you don’t want to continue silently if the head isn’t a linear layer that you can replace with your own.

Best regards

Thomas