How to use pretrained weights to initialize the weights in the next iteration?

I have a model architecture. I have saved the entire model with torch.save() after training for some n iterations. I want to run another iteration of my code using the pretrained weights of the model I saved previously.

Edit: I want the weight initialization for the new iteration to be done from the weights of the pretrained model.

Right now, I do something like:

import torch
import torch.nn as nn

# default_lr = 5
# default_weight_decay = 0.001
# model_io = path to the pretrained model in .pth format
model = torch.load(model_io)
optim = torch.optim.Adam(model.parameters(), lr=default_lr, weight_decay=default_weight_decay)
loss_new = nn.BCELoss()
epochs = default_epoch

def training_loop():
    ...
    outputs = model(input)
    ...

# similarly for the test loop

The architecture and everything else in my new iteration is the same right now, but I might want to fiddle with it in the future.

Am I missing something? I have to train for a large number of epochs on a huge number of samples, so I cannot afford to wait for the results and only then figure things out.

Thank you!


Where is the network here?
I see only the weights stored in model.
You need to load your network with those weights.
Typically, it goes this way:

# simple initialisation of your network
net = MyNet()
# loading the pretrained weights
weights = torch.load("*.pth")
# placing the weights inside your network
net.load_state_dict(weights)
# using your pretrained network
outputs = net(input)

In your case, there’s no network defined.
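
If the aim is to start a fresh training run from those weights, a minimal sketch could look like the following. It assumes the .pth file was saved with torch.save(model.state_dict(), ...); MyNet, default_lr, default_weight_decay and train_loader are placeholders for your own network class, hyperparameters and DataLoader.

import torch
import torch.nn as nn

# re-create the architecture and fill it with the saved weights
net = MyNet()
net.load_state_dict(torch.load("pretrained.pth"))

# fresh optimizer and loss for the new run
optimizer = torch.optim.Adam(net.parameters(), lr=default_lr, weight_decay=default_weight_decay)
criterion = nn.BCELoss()

net.train()
for inputs, targets in train_loader:
    optimizer.zero_grad()
    outputs = net(inputs)
    loss = criterion(outputs, targets)
    loss.backward()
    optimizer.step()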


@chetan_patil
Oh! Makes sense. My understanding was incorrect. I thought torch.load would load the entire model along with the parameters, architecture and the weights.

It appears it will only load the weights and nothing else. Is that correct?

Edit: Just to add, I don’t plan to resume training. I intend to save the model and use it for a separate training run with the same parameters. Think of it like using a saved model, with its weights etc., for a larger run with more samples.

Yes, torch.load() loads the weights; however, the saved state also contains BatchNorm’s running statistics, which you might need during the testing phase.
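
For example, the state_dict of a model with a BatchNorm layer holds the running statistics as buffers next to the learnable parameters. A small illustrative sketch (not your actual model):

import torch.nn as nn

# a tiny model with a BatchNorm layer, just to inspect what its state_dict holds
model = nn.Sequential(nn.Linear(4, 4), nn.BatchNorm1d(4))
print(model.state_dict().keys())
# -> '0.weight', '0.bias', '1.weight', '1.bias',
#    '1.running_mean', '1.running_var', '1.num_batches_tracked'

# switch to eval() at test time so those running statistics are actually used
model.eval()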

Thanks @chetan_patil. For my current use case, I don’t need batch normalization. Since my architecture is exactly the same, I expect torch.save() to do the job, i.e. using the saved model for another iteration with more samples.

Yes @copperwiring, torch.save() and torch.load() are the ones to use.
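
For completeness, a minimal sketch of the whole-model route: torch.save(model) pickles the full module, so the class definition must still be importable when you load it back. Here model, default_lr and default_weight_decay are placeholders for your own objects and hyperparameters.

import torch

# end of the first run: save the full module (architecture + weights + buffers)
torch.save(model, "model_full.pth")

# start of the next run: load it back and keep training on more samples
model = torch.load("model_full.pth")
optimizer = torch.optim.Adam(model.parameters(), lr=default_lr, weight_decay=default_weight_decay)
model.train()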