How to perform finetuning in Pytorch?

@zhoubinxyz, could you add a more complete example illustrating how to fine-tune a VGG model on a custom dataset using PyTorch?


You can get the basic idea from this PR.


There is an option named --pretrained in the imagenet main.py file. May I ask what happens if I use this option with a custom dataset, like this:
python main.py --arch=alexnet --pretrained my_custom_dataset
It seems like that would amount to fine-tuning.

I think --pretrained is meant for evaluation mode. The script doesn’t support finetuning at the moment.
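If you want to finetune on a custom dataset, one way is to load the pretrained weights yourself and swap out the classifier before training. A rough sketch (assuming torchvision's alexnet; num_classes is just an example value):

import torch.nn as nn
import torchvision.models as models

# Load ImageNet-pretrained weights
model = models.alexnet(pretrained=True)

# Freeze the convolutional feature extractor
for p in model.features.parameters():
    p.requires_grad = False

# Replace the final classifier layer to match your dataset's class count
num_classes = 10  # example value for a custom dataset
model.classifier[6] = nn.Linear(4096, num_classes)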


@apaszke I am referencing this PR for fine-tuning. For alexnet and vggnet, the original code replaces all the fully-connected layers. May I ask:

  • How can I replace only the last fully-connected layer for fine-tuning and freeze the other fully-connected layers?
  • Is this forward the right way to code it? You gave some reference code above:

def forward(self, x):
    return self.last_layer(self.pretrained_model(x))

Original fine-tuning code:

class FineTuneModel(nn.Module):
    def __init__(self, original_model, arch, num_classes):
        super(FineTuneModel, self).__init__()

        if arch.startswith('alexnet') :
            self.features = original_model.features
            self.classifier = nn.Sequential(
                nn.Dropout(),
                nn.Linear(256 * 6 * 6, 4096),
                nn.ReLU(inplace=True),
                nn.Dropout(),
                nn.Linear(4096, 4096),
                nn.ReLU(inplace=True),
                nn.Linear(4096, num_classes),
            )
            self.modelName = 'alexnet'
        elif arch.startswith('resnet') :
            # Everything except the last linear layer
            self.features = nn.Sequential(*list(original_model.children())[:-1])
            self.classifier = nn.Sequential(
                nn.Linear(512, num_classes)
            )
            self.modelName = 'resnet'
        elif arch.startswith('vgg16'):
            self.features = original_model.features
            self.classifier = nn.Sequential(
                nn.Dropout(),
                nn.Linear(25088, 4096),
                nn.ReLU(inplace=True),
                nn.Dropout(),
                nn.Linear(4096, 4096),
                nn.ReLU(inplace=True),
                nn.Linear(4096, num_classes),
            )
            self.modelName = 'vgg16'
        else:
            raise ValueError("Finetuning not supported on this architecture yet")

        # Freeze those weights
        for p in self.features.parameters():
            p.requires_grad = False


    def forward(self, x):
        f = self.features(x)
        if self.modelName == 'alexnet' :
            f = f.view(f.size(0), 256 * 6 * 6)
        elif self.modelName == 'vgg16':
            f = f.view(f.size(0), -1)
        elif self.modelName == 'resnet' :
            f = f.view(f.size(0), -1)
        y = self.classifier(f)
        return y

This post should help you.
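In short, to replace only the last fully-connected layer and freeze everything else, something along these lines should work (a sketch assuming torchvision's vgg16 layout; num_classes is a placeholder):

import torch
import torch.nn as nn
import torchvision.models as models

model = models.vgg16(pretrained=True)

# Freeze all existing parameters, including the earlier fully-connected layers
for p in model.parameters():
    p.requires_grad = False

# Swap only the last classifier layer; its fresh parameters require grad by default
num_classes = 3  # placeholder
model.classifier[6] = nn.Linear(4096, num_classes)

# Optimize only the unfrozen parameters
optimizer = torch.optim.SGD(
    filter(lambda p: p.requires_grad, model.parameters()),
    lr=0.001, momentum=0.9)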


I added the following lines to the imagenet example, using a pretrained resnet18 model.

for param in model.parameters():
    param.requires_grad = False

# Replace the last fully-connected layer
# Parameters of newly constructed modules have requires_grad=True by default
model.fc = torch.nn.Linear(512, 3)

optimizer = torch.optim.SGD(model.fc.parameters(), args.lr,
                            momentum=args.momentum,
                            weight_decay=args.weight_decay)

But then I get the following error:

File "main.py", line 234, in train
    loss.backward()
File "/usr/local/lib/python2.7/dist-packages/torch/autograd/variable.py", line 146, in backward
    self._execution_engine.run_backward((self,), (gradient,), retain_variables)
RuntimeError: there are no graph nodes that require computing gradients

I would like to freeze all parameters of the original ResNet18 and just learn the last layer with 3 classes. How should I do this correctly? Based on information from the forum, this should be the working version.

That should work. Can you post the entire code, just to check whether there is some error there, and maybe try to run it here?

Here is the full code:
http://pastebin.com/g6xxBDmr

It is the original ImageNet example with some Visdom elements that I tried to use, so to run it on your own you should start a Visdom server or delete those lines.

Hello,

Do you know how I can change the cost function while finetuning a pre-trained model (like ResNet-18 or VGG-16), or create a customized cost function and use it?

Thanks.

The cost function lies outside the model definition; change it as you would normally.
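For example, a minimal sketch (the custom cost below is purely illustrative):

import torch
import torch.nn as nn
import torch.nn.functional as F

# Switching to a built-in loss is just a matter of instantiating a different criterion
criterion = nn.CrossEntropyLoss()

# A customized cost function is an ordinary function of tensors
def my_cost(output, target):
    # mean squared error between softmax probabilities and one-hot targets (illustrative)
    probs = F.softmax(output, dim=1)
    one_hot = torch.zeros_like(probs).scatter_(1, target.unsqueeze(1), 1.0)
    return ((probs - one_hot) ** 2).mean()

# The training loop is unchanged: loss = criterion(output, target) (or my_cost(output, target))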

What is the learning rate for base_params and fc.parameters() in this example (How to perform finetuning in Pytorch)? @apaszke
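For reference, the usual pattern is to give each parameter group its own learning rate. The values below are only placeholders, and model is assumed to be a ResNet-style network with an fc attribute:

import torch

# Backbone parameters vs. the freshly replaced classifier layer
base_params = [p for name, p in model.named_parameters()
               if not name.startswith('fc')]

optimizer = torch.optim.SGD([
    {'params': base_params, 'lr': 1e-4},           # small LR for pretrained layers
    {'params': model.fc.parameters(), 'lr': 1e-2}  # larger LR for the new layer
], momentum=0.9, weight_decay=1e-4)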

@apaszke Do you know if the underlying code has changed at all since you posted this? I get

AttributeError: type object 'object' has no attribute '__getattr__'

when trying

optimizer = torch.optim.SGD(model.fc.parameters(), opt.lr, momentum=opt.momentum, weight_decay=opt.weight_decay)

@achaiah Are you using master?

I think there was a bad commit during the last week; I got an AttributeError for previously working examples and had to roll back. It seems it was fixed in the past couple of days. Try updating.

My error for reference (ignore the torchsample warning, it's always there):

Aha, you’re correct. I updated to the latest version from pytorch.org and the error has gone away.


Can you post the whole code for finetuning here, so we can benefit from it, please?

For future readers of this thread: there's a tutorial on transfer learning in the official PyTorch tutorials.


For someone like me who is a newbie, that tutorial is confusing and not helpful…

Let me know if you face any problems.


How do I pause and resume training in PyTorch? Suppose I have trained until epoch 2000 and want to continue until epoch 4000. I thought we first load the weights and fine-tune from there, but my loss stays the same. Do I need to set requires_grad=True on all parameters to resume training? What is the right way to resume training in PyTorch?
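Resuming is usually handled with a checkpoint that stores the model and optimizer state plus the epoch; you do not need to touch requires_grad for that. A rough sketch (the file name and train_one_epoch are placeholders):

import torch

# Saving, e.g. at the end of epoch 2000
torch.save({
    'epoch': epoch,
    'model_state': model.state_dict(),
    'optimizer_state': optimizer.state_dict(),
}, 'checkpoint.pth')

# Resuming: rebuild the model and optimizer, then restore their state
checkpoint = torch.load('checkpoint.pth')
model.load_state_dict(checkpoint['model_state'])
optimizer.load_state_dict(checkpoint['optimizer_state'])
start_epoch = checkpoint['epoch'] + 1

for epoch in range(start_epoch, 4000):
    train_one_epoch(model, optimizer)  # placeholder for your training loop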