Transfer learning using VGG16

I want to use the VGG16 network for transfer learning. I am following the transfer learning tutorial, which is based on ResNet, and I want to replace these lines:

model_ft = models.resnet18(pretrained=True)
num_ftrs = model_ft.fc.in_features
model_ft.fc = nn.Linear(num_ftrs, 2)
optimizer_ft = optim.SGD(model_ft.parameters(), lr=0.001, momentum=0.9)

with their equivalent for VGG16.

My attempt is:

model_ft = models.vgg16(pretrained=True)
num_ftrs = model_ft.fc.in_features
model_ft.fc = nn.Linear(num_ftrs, 2)
optimizer_ft = optim.SGD(model_ft.parameters(), lr=0.001, momentum=0.9)

where, as far as I understand, the two middle lines are required to replace the classification layer (from the 1000 ImageNet classes to 2). The problem is that the VGG16 class does not have an “.fc” attribute, so running these lines raises an error.

What is the best way to replace the corresponding lines from the ResNet transfer learning tutorial?

Since I am new to PyTorch (and machine learning in general), any further relevant details about the structure of the VGG16 class (even details that are not strictly required for this specific implementation) would be greatly appreciated.

Thanks!

For VGG16 you would have to use model_ft.classifier. You can find the corresponding code here.
Here is a small example of how to reset the last layer. Of course, you could also replace the whole classifier if that's what you want.

model = models.vgg16(pretrained=False)
model.classifier[-1] = nn.Linear(in_features=4096, out_features=num_classes)

Many thanks ptrblck! For future reference, I also found this really helpful tutorial:
https://www.kaggle.com/carloalbertobarbano/vgg16-transfer-learning-pytorch

I have a similar question, but for the FCN-ResNet101 segmentation model (fcn_resnet101). I am following this tutorial and trying to adapt this part of its code to fcn_resnet101.

The code from the tutorial is:

    if model_name == "resnet":
        """ Resnet18
        """
        model_ft = models.resnet18(pretrained=use_pretrained)
        set_parameter_requires_grad(model_ft, feature_extract)
        num_ftrs = model_ft.fc.in_features
        model_ft.fc = nn.Linear(num_ftrs, num_classes)
        input_size = 224

and so far I have changed it to:

    if model_name == "resnet":
        """ FCN_resnet101
        """
        model_ft = models.segmentation.fcn_resnet101(pretrained=use_pretrained)
        set_parameter_requires_grad(model_ft, feature_extract)
        num_ftrs = model_ft.fc.in_features
        model_ft.fc = nn.Linear(num_ftrs, num_classes)
        input_size = 768, 1024

When I run this, I get the following error: ‘FCN’ object has no attribute ‘fc’

So how can I change the two lines below to work with the FCN segmentation model?

        num_ftrs = model_ft.fc.in_features
        model_ft.fc = nn.Linear(num_ftrs, num_classes)

Since this is a segmentation model, the output layer is a convolution (producing one score map per class) instead of a linear layer.
If you want to train your model from scratch, you could just use the num_classes argument:

modelA = models.segmentation.fcn_resnet101(
    pretrained=False, num_classes=2)

On the other hand, if you just want to use the pretrained model and create a new classification layer, you could use:

modelB = models.segmentation.fcn_resnet101(pretrained=True)
modelB.classifier[4] = nn.Conv2d(512, num_classes, 1, 1)

Thanks!

I am fine-tuning a pretrained model with my own data, so the second method would work for me.

One follow-up question:

Wouldn’t I have to fetch the number of in_channels of the existing pretrained model, similar to how it's done in the example with ‘num_ftrs’?

So in the tutorial there is this line before creating a new layer:

num_ftrs = model_ft.fc.in_features

Would the equivalent for segmentation be the line below?

in_chnls = modelB.classifier[4].in_channels

And then:

modelB.classifier[4] = nn.Conv2d(in_chnls, num_classes, 1, 1)

Yes, that would be the corresponding code.
You could also read the kernel_size and stride from the existing layer; I hard-coded both as 1 in my code example.

Thanks! I am getting this part to work now!