How to fine-tune my own PyTorch model

Hello, I am new to PyTorch. My problem is that I have to fine-tune my own model. I have seen examples of fine-tuning the torchvision models, which download the .pth file and start training from it.

Likewise, I have my own .pth file and neural network model, and I want to fine-tune it.
Could you please help with an example for my own model?

If you have your own .pth model file, just load it and fine-tune for the number of epochs you want.

import torch

model = get_model()  # your own function that builds the network architecture
checkpoint = torch.load(path_to_your_pth_file)
model.load_state_dict(checkpoint['state_dict'])

finetune_epochs = 10  # number of epochs you want to fine-tune
for epoch in range(finetune_epochs):
    # assuming you have functions for training and validating models
    train_model(model)
    validate_model(model)
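
For reference, here is a minimal sketch of what such training and validation functions might look like. The signatures are assumptions (they take extra arguments such as a data loader and an optimizer, unlike the one-argument calls above), and the tiny linear model and random data are only hypothetical stand-ins so the snippet runs on its own.

```python
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset


def train_model(model, loader, optimizer, criterion):
    """Run one training epoch and return the mean loss per sample."""
    model.train()
    total_loss = 0.0
    for inputs, targets in loader:
        optimizer.zero_grad()
        loss = criterion(model(inputs), targets)
        loss.backward()
        optimizer.step()
        total_loss += loss.item() * inputs.size(0)
    return total_loss / len(loader.dataset)


@torch.no_grad()
def validate_model(model, loader):
    """Return classification accuracy over the loader."""
    model.eval()
    correct = 0
    for inputs, targets in loader:
        predictions = model(inputs).argmax(dim=1)
        correct += (predictions == targets).sum().item()
    return correct / len(loader.dataset)


# Hypothetical stand-ins: random data and a tiny model instead of your own.
torch.manual_seed(0)
dataset = TensorDataset(torch.randn(32, 8), torch.randint(0, 3, (32,)))
loader = DataLoader(dataset, batch_size=8)
model = nn.Linear(8, 3)
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
criterion = nn.CrossEntropyLoss()

losses = []
finetune_epochs = 3
for epoch in range(finetune_epochs):
    losses.append(train_model(model, loader, optimizer, criterion))
    accuracy = validate_model(model, loader)
    print(f"epoch {epoch}: loss={losses[-1]:.4f} accuracy={accuracy:.2f}")
```

In practice you would use separate training and validation loaders; the same loader is reused here only to keep the sketch short.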

Hello, and thanks for the quick reply.
I tried the steps above and got the following error. Please help.

import torch
import torch.nn as nn

path_save = "./model_new_10_300_0.005.pth"
model = torch.hub.load('pytorch/vision:v0.5.0', 'resnext50_32x4d', pretrained=False)
for param in model.parameters():
    param.requires_grad = False

model.fc = nn.Linear(2048, num_classes)  # num_classes is defined elsewhere
print(model)
checkpoint = torch.load(path_save)
model.load_state_dict(checkpoint['state_dict'])

The error is:
model.load_state_dict(checkpoint['state_dict'])
KeyError: 'state_dict'

How to modify the number of categories in the last layer?

You would access the fully connected layer and replace it with a new one whose output size is the number of classes in your new task. For example:

model = get_model()
checkpoint = torch.load(path_to_your_pth_file)
model.load_state_dict(checkpoint['state_dict'])
# in_features must match the output of the previous layer; out_features is your number of classes
model.fc = nn.Linear(2048, 10)

You can then decide how to proceed with training, since your last layer now has random weights while the rest of the model has trained weights. You can train either just the last layer or the whole network.
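
As a rough sketch of the two options, the snippet below freezes every parameter, swaps in a new head, and passes only the trainable parameters to the optimizer. `TinyNet` is a hypothetical stand-in for your own network (any model with an `fc` attribute would work the same way), and the layer sizes are arbitrary.

```python
import torch
import torch.nn as nn


class TinyNet(nn.Module):
    """Hypothetical stand-in for a pretrained network with an `fc` head."""

    def __init__(self, num_classes=5):
        super().__init__()
        self.features = nn.Sequential(nn.Linear(8, 16), nn.ReLU())
        self.fc = nn.Linear(16, num_classes)

    def forward(self, x):
        return self.fc(self.features(x))


model = TinyNet()

# Option A: train only the last layer.
# Freeze everything first, then replace the head; a freshly created
# layer has requires_grad=True by default.
for param in model.parameters():
    param.requires_grad = False
model.fc = nn.Linear(model.fc.in_features, 10)

# Give the optimizer only the parameters that should be updated.
trainable_params = [p for p in model.parameters() if p.requires_grad]
optimizer = torch.optim.SGD(trainable_params, lr=0.01)

trainable = [name for name, p in model.named_parameters() if p.requires_grad]
print(trainable)  # ['fc.weight', 'fc.bias']

# Option B: train the whole network instead.
# for param in model.parameters():
#     param.requires_grad = True
# optimizer = torch.optim.SGD(model.parameters(), lr=0.001)
```

Reading `in_features` from the existing layer avoids hard-coding a value such as 2048, which only holds for some architectures.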


This is not working for me.

KeyError: 'state_dict'?

Any idea?

checkpoint['state_dict'] assumes that checkpoint is a dict with the state_dict key.
This doesn’t seem to be the case in your setup, so check what checkpoint is and what it contains.
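
One way to check is to print the type and keys of the loaded object. A common cause of this error is that the file was saved with torch.save(model.state_dict(), path), in which case the loaded object already is the state_dict and can be passed to load_state_dict directly. The sketch below illustrates this with a small in-memory example; `extract_state_dict` is a hypothetical helper, and the key names it tries are just common conventions, not something your file is guaranteed to use.

```python
import io

import torch
import torch.nn as nn


def extract_state_dict(checkpoint):
    """Hypothetical helper: return a state_dict from common checkpoint layouts."""
    if isinstance(checkpoint, nn.Module):
        # The whole model was saved with torch.save(model, path).
        return checkpoint.state_dict()
    if isinstance(checkpoint, dict):
        # Wrapped checkpoint, e.g. {'state_dict': ..., 'epoch': ...}.
        for key in ("state_dict", "model_state_dict"):
            if key in checkpoint:
                return checkpoint[key]
        # Otherwise assume the dict is already the state_dict.
        return checkpoint
    raise TypeError(f"Unsupported checkpoint type: {type(checkpoint)}")


# Simulate a file that holds only the state_dict (no wrapping dict).
model = nn.Linear(4, 2)
buffer = io.BytesIO()
torch.save(model.state_dict(), buffer)
buffer.seek(0)

checkpoint = torch.load(buffer, map_location="cpu")
print(type(checkpoint))         # an OrderedDict, not a model
print(list(checkpoint.keys()))  # ['weight', 'bias'] -> no 'state_dict' key

fresh_model = nn.Linear(4, 2)
fresh_model.load_state_dict(extract_state_dict(checkpoint))
```

If the printed keys look like layer names (for example names ending in `.weight` or `.bias`), the checkpoint is a bare state_dict and model.load_state_dict(checkpoint) should work without indexing.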