How to add a fully connected layer to a pretrained ResNet model in PyTorch

I have a pretrained ResNet-152 model. It outputs a 2048-dimensional feature vector. The code is given below.

import torch

import torch.nn as nn

import torchvision.models as models

import torchvision.transforms as transforms

from torch.autograd import Variable

from PIL import Image

# Load ResNet-152 with ImageNet-pretrained weights.
# NOTE(review): torchvision >= 0.13 deprecates `pretrained=True` in favour of
# `weights=models.ResNet152_Weights.IMAGENET1K_V1`; `pretrained=True` is kept
# so the script still runs on older torchvision releases.
model = models.resnet152(pretrained=True)

# The layer whose output get_vector() captures: the global average pool that
# feeds the final `fc` layer. Plain attribute access raises AttributeError on a
# wrong name, whereas `model._modules.get(...)` would silently return None.
layer = model.avgpool

# Set model to evaluation mode (BatchNorm layers use their running statistics).
model.eval()

# Image transforms, applied by get_vector() in this order:
# resize to 224x224, convert to a float tensor, then normalize with the
# ImageNet per-channel mean/std.
# `transforms.Scale` was removed from torchvision; `Resize` is its replacement.
scaler = transforms.Resize((224, 224))

normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])

to_tensor = transforms.ToTensor()

def get_vector(image_name):
    """Return the output of the module-level ``layer`` for one image as a vector.

    A forward hook on ``layer`` copies that layer's output into a
    preallocated 2048-element buffer while ``model`` runs on the image; the
    hook is always detached afterwards.

    Args:
        image_name: Path to an image file readable by Pillow.

    Returns:
        A 1-D NumPy array with 2048 elements.
    """
    # 1. Load the image with Pillow. Converting to RGB makes grayscale, RGBA and
    #    palette images yield the 3 channels that `normalize` (3 means/stds)
    #    expects; the `with` block closes the underlying file handle.
    with Image.open(image_name) as src:
        img = src.convert('RGB')

    # 2. Resize, tensorize and normalize, then add a leading batch dimension of 1.
    #    (The old `Variable` wrapper is a no-op in modern PyTorch.)
    t_img = normalize(to_tensor(scaler(img))).unsqueeze(0)

    # 3. Buffer that will hold the feature vector; the 'avgpool' layer of
    #    resnet152 has an output size of 2048.
    my_embedding = torch.zeros(2048)

    # 4. Hook that copies the layer's output, flattened to its channel
    #    dimension (valid for a batch of one), into the buffer.
    def copy_data(m, i, o):
        my_embedding.copy_(o.detach().reshape(o.size(1)))

    # 5./6. Attach the hook and run the model. This is inference only, so
    #    autograd bookkeeping is skipped. `finally` guarantees the hook is
    #    removed even if the forward pass raises, so hooks never accumulate on
    #    the shared `layer` across calls.
    h = layer.register_forward_hook(copy_data)
    try:
        with torch.no_grad():
            model(t_img)
    finally:
        h.remove()

    # 7. Return the feature vector.
    return my_embedding.numpy()

# Demo: extract the feature vector for one sample COCO image and show it,
# followed by its element count (NumPy's `.size`).
sample_image_path = '/content/drive/MyDrive/images/COCO_val2014_000000000042.jpg'
a = get_vector(sample_image_path)

print(a)
print(a.size)

I want the 2048-dimensional feature vector returned by ResNet to be passed through a fully connected layer that reduces it to a 64-dimensional vector. How can I do that?

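When you print the model (`print(model)`), you should see that its last layer is `model.fc`, a fully connected layer that takes the 2048-dimensional pooled features as input. You can create a new `nn.Linear` and assign it to `model.fc`, e.g. `model.fc = nn.Linear(2048, 64)`; calling `model(t_img)` then returns a 64-dimensional output. Note that the new layer is randomly initialized, so it must be trained before its outputs are meaningful.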