Use pre-trained ViT as a backbone

Hi, I’m very new to this. I want to use the ViT-B/16 pre-trained on ImageNet as a backbone for image classification on a different dataset. The image representation produced by this backbone is then used with a kNN classifier. My code looks like this:

Initializing the model:

net = Embedder("vit_b_16", pretrained_flag = True)

The Embedder class:

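import torch.nn as nn
from torchvision import models

class Embedder(nn.Module):

    def __init__(self, architecture, pretrained_flag=True):

        super().__init__()

        # load the base model from torchvision's pretrained models (ImageNet pretrained)
        # note: torchvision 0.13+ deprecates `pretrained=` in favor of `weights=`
        network = getattr(models, architecture)(pretrained=pretrained_flag)

        if architecture.startswith('vit'):
            self.backbone = network
        else:
            # without this, self.backbone would be undefined and forward() would fail later
            raise ValueError("Unsupported architecture: " + architecture)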

    def forward(self, img):
        '''
        Output has shape: batch size x descriptor dimension
        '''
        x = self.backbone(img).squeeze(-1).squeeze(-1)

        return x

And here I extract the embeddings using torchvision’s feature_extraction tool, stopping at node getitem_5 (the node just before the heads node):

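import numpy as np
import torch
from torchvision.models.feature_extraction import create_feature_extractor

def extract_embeddings(net, dataloader, ms=[1], msp=1, print_freq=None, verbose=False):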

    feature_extractor = create_feature_extractor(net, return_nodes=['backbone.getitem_5'])

    if verbose:

        if len(ms) == 1:
            print("Singlescale extraction")
        else:
            print("Multiscale extraction at scales: " + str(ms))
    
    net.eval()
    feature_extractor.eval()  # put the traced extractor in eval mode as well
    
    with torch.no_grad():

        vecs = np.zeros((768, len(dataloader.dataset))) #768 -> vb16; 1024 -> vl16
    
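        # assumes batch_size == 1, since each batch fills exactly one column of vecs
        for i, batch in enumerate(dataloader):

            if len(ms) == 1 and ms[0] == 1:
                # the extractor takes only the input tensor and returns a dict keyed by node name
                vecs[:, i] = feature_extractor(batch[0])['backbone.getitem_5'] #.cuda()
            
            else:
                vecs[:, i] = feature_extractor(batch[0])['backbone.getitem_5'] 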

            # report progress every print_freq images rather than on every image
            if print_freq is not None and i % print_freq == 0:
                print("image: " + str(i))

    return vecs.T

Am I using the ViT correctly? Do I need to add anything to the forward function?
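The kNN step mentioned at the top is not shown in the code above. A minimal sketch of what it could look like, assuming cosine similarity and majority voting (both are assumptions, as are the names `knn_predict`, `train_vecs`, `train_labels` and `query_vecs`, and the toy 2-D data standing in for the 768-dimensional embeddings):

```python
import numpy as np

def knn_predict(train_vecs, train_labels, query_vecs, k=3):
    """Majority-vote kNN on L2-normalised embeddings (cosine similarity)."""
    train = train_vecs / np.linalg.norm(train_vecs, axis=1, keepdims=True)
    query = query_vecs / np.linalg.norm(query_vecs, axis=1, keepdims=True)
    sims = query @ train.T                      # (n_query, n_train)
    nearest = np.argsort(-sims, axis=1)[:, :k]  # indices of the k most similar rows
    votes = train_labels[nearest]               # (n_query, k) neighbour labels
    return np.array([np.bincount(v).argmax() for v in votes])

# Toy embeddings: one row per image, as returned by extract_embeddings.
train_vecs = np.array([[1.0, 0.0], [0.9, 0.1], [0.0, 1.0], [0.1, 0.9]])
train_labels = np.array([0, 0, 1, 1])
query_vecs = np.array([[1.0, 0.2], [0.2, 1.0]])

preds = knn_predict(train_vecs, train_labels, query_vecs, k=3)
print(preds)  # [0 1]
```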