I have a model architecture, and I have saved the entire model with
torch.save() after some n iterations. I want to run another iteration of my code using the pretrained weights of the model I saved previously.
Edit: I want the weight initialization for the new iteration to be done from the weights of the pretrained model.
Right now, I do something like:
import torch
from torch.nn import BCELoss

# default_lr = 5
# default_weight_decay = 0.001
# model_io = the pretrained model in .pth format
model = torch.load(model_io)
optim = torch.optim.Adam(model.parameters(), lr=default_lr, weight_decay=default_weight_decay)
loss_new = BCELoss()
epochs = default_epoch
outputs = model(input)
# similarly for the test loop
The architecture of my new iteration is the same right now, but I might want to change it in the future.
Am I missing something? I have to train for many epochs on a huge number of samples, so I can't afford to wait for the results before figuring things out.
Where is the network here?
I see only the weights stored in model_io.
You need to load your network with the load_state_dict() method.
Typically, it goes this way:
# simple initialisation of your network
net = MyNet()
# loading the pretrained weights
weights = torch.load("*.pth")
# placing the weights inside your network
net.load_state_dict(weights)
# using your pretrained network
outputs = net(input)
In your case, there’s no network defined.
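To make the difference concrete, here is a self-contained sketch of the two ways of saving; MyNet and the file names are placeholders, not code from the thread. Saving the entire module pickles the class, so torch.load() hands back a ready model; saving only the state_dict hands back a plain dict of tensors that must be copied into a freshly constructed network:

```python
import torch
import torch.nn as nn

# Hypothetical tiny network standing in for MyNet.
class MyNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(4, 1)

    def forward(self, x):
        return torch.sigmoid(self.fc(x))

net = MyNet()

# Option 1: save the ENTIRE pickled module. torch.load() then
# returns a ready-to-use model (the class definition must still be
# importable at load time; newer PyTorch also needs weights_only=False
# because its default refuses to unpickle arbitrary objects).
torch.save(net, "whole_model.pth")
model = torch.load("whole_model.pth", weights_only=False)

# Option 2: save only the state_dict. torch.load() then returns a
# plain dict of tensors, so you must build the network first and
# copy the weights in with load_state_dict().
torch.save(net.state_dict(), "weights_only.pth")
weights = torch.load("weights_only.pth")
fresh = MyNet()
fresh.load_state_dict(weights)
```

With option 2 the "network defined" step (net = MyNet()) is mandatory, which is exactly what is missing in the snippet above.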
Oh! That makes sense. My understanding was incorrect. I thought
torch.load would load the entire model: the architecture along with the parameters and weights.
It appears it only loads the weights and nothing else. Is that correct?
Edit: Just to add, I don't plan to resume training. I intend to save the model and use it for a separate training run with the same parameters. Think of it like using a saved model (weights etc.) for a larger run with more samples.
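For that use case, a common pattern is to warm-start: build a fresh copy of the network, copy the pretrained weights in, and create a brand-new optimizer and loss for the new run. A minimal sketch, where MyNet, the hyperparameters, and the toy data are all placeholders:

```python
import torch
import torch.nn as nn

# Placeholder network; substitute your real architecture.
class MyNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(4, 1)

    def forward(self, x):
        return torch.sigmoid(self.fc(x))

# Pretend this state_dict came from the earlier run.
pretrained = MyNet()
torch.save(pretrained.state_dict(), "pretrained.pth")

# --- new run: weights initialised from the saved model ---
net = MyNet()
net.load_state_dict(torch.load("pretrained.pth"))

# Fresh optimizer and loss for the new run; the old optimizer
# state is intentionally discarded since training is not resumed.
optim = torch.optim.Adam(net.parameters(), lr=1e-3, weight_decay=0.001)
loss_fn = nn.BCELoss()

# One illustrative training step on toy data.
x = torch.randn(8, 4)
y = torch.randint(0, 2, (8, 1)).float()
optim.zero_grad()
loss = loss_fn(net(x), y)
loss.backward()
optim.step()
```

Since only the weights are reused, nothing here ties the new run to the old one's epoch count or sample size.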
torch.load() loads the weights; however, the state dict it returns also contains BatchNorm's running statistics, which you might need during the testing phase.
Thanks @chetan_patil. For my use case right now, I don't need batch normalization. Since my architecture is exactly the same, I expect
torch.save() and torch.load() to do the job, i.e. using the saved model for another run with more samples.