Hello,
I have a time limit on the compute cluster, so I'm trying to train a model across multiple runs. This is how I save the model after the first run (R1):
torch.save(model.state_dict(), "trained.pth")
For the next run (R2), I load it using:
model = myCustomNN()
model.to(torch.double)
model.to(device)
model.load_state_dict(torch.load("trained.pth", map_location=device))
model.train(True)
optimizer = Adam(model.parameters(), lr=learning_rate)
The train loss dropped from 10 to 1e-2 during R1. However, when I start R2, the train loss is back at 10 and training seems to start all over again. What am I doing wrong?
Edit: Wanted to add that the learning rate is the same across runs.
When you use the Adam optimizer, it keeps running estimates (the first and second moments of the gradients) that need to be saved too.
Without those parameters, when you restart the run the model diverges in one iteration, creating the illusion that you didn't save the weights properly.
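The usual fix is to write both state dicts into a single checkpoint; a minimal sketch (the file and key names here are just examples):

# End of R1: Adam's moment estimates and step counts live in
# optimizer.state_dict(), not in the model, so save both together
torch.save({
    'model_state_dict': model.state_dict(),
    'optimizer_state_dict': optimizer.state_dict(),
}, "trained.pth")

# Start of R2: rebuild the model and optimizer, then restore both states
checkpoint = torch.load("trained.pth", map_location=device)
model.load_state_dict(checkpoint['model_state_dict'])
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])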
Thanks! I tried changing the optimizer to SGD for R2 and it doesn’t seem to help. I guess I’ll have to start a new run and store the optimizer parameters as well.
Is there any way to salvage the model saved from R1? I was under the impression that the model state and the optimizer are two mutually exclusive things and that I could simply restart a run with a new optimizer. Does this also mean that I cannot start R2 with a new learning rate?
They are mutually exclusive. If you restart R2 with SGD, it should work as long as you select a proper LR (which is typically different from the one that is good for Adam). Maybe you just need to find a good LR for that stage of your training with SGD.
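Something along these lines (the LR value is only a placeholder; you would need to tune it for your setup):

from torch.optim import SGD

# fresh SGD on top of the loaded weights; no saved Adam state is needed
optimizer = SGD(model.parameters(), lr=1e-3)  # placeholder LR, tune it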
Also, note that model.save_dict() is not a PyTorch method. You may want to double-check what you are doing there.
Sorry about the typo. That was supposed to be model.state_dict().
I tried saving the optimizer state as well, testing with a shorter run, and it doesn't seem to help. This is how my code for R2 looks now:
checkpoint = torch.load("trained.pth", map_location=device)
model = myCustomNN()
model.to(torch.double)
model.to(device)
model.load_state_dict(checkpoint['model_state_dict'])
optimizer = Adam(model.parameters(), lr=learning_rate)
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
model.train(True)
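For reference, the save at the end of R1 was changed to match, along these lines:

# End of R1: write both state dicts into one checkpoint
torch.save({
    'model_state_dict': model.state_dict(),
    'optimizer_state_dict': optimizer.state_dict(),
}, "trained.pth")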
I am afraid I have no other simple suggestion to give you. That's the canonical way of working, as you can see in Saving and Loading Models — PyTorch Tutorials 2.4.0+cu121 documentation.
I can suggest you compare inference for a given sample right before saving and right after loading, to check that you are effectively loading the weights.
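A quick sketch of that check (sample_input stands in for any fixed batch you keep around between runs):

# R1, right before torch.save: record the output for a fixed sample
model.eval()
with torch.no_grad():
    ref = model(sample_input)
torch.save(ref, "ref_output.pth")

# R2, right after load_state_dict: recompute on the same sample and compare
model.eval()
with torch.no_grad():
    out = model(sample_input)
ref = torch.load("ref_output.pth", map_location=device)
print(torch.allclose(out, ref))  # True means the weights were loaded
model.train(True)  # switch back before resuming training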
Thanks for your time! I’ve done the inference check and the weights seem to be loaded fine.