Save the best model

    model = model.train()
    best_accuracy = 0
    ...
    for epoch in range(100):
        for idx, data in enumerate(data_loader):
            ...
        if cur_accuracy > best_accuracy:
            best_model = model
    torch.save(best_model.state_dict(), 'model.pt')

Does this save the model with the best accuracy correctly?


This code won’t work: best_model holds a reference to model, which is updated in each epoch, so the “best” snapshot just follows the live model.
You could use copy.deepcopy to take a deep copy of the parameters, or use the save_checkpoint method provided in the ImageNet example.
Here is a small example for demonstrating the issue with your code:

import copy

import torch
import torch.nn as nn

model = nn.Linear(10, 2)
criterion = nn.MSELoss()
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)

for epoch in range(10):
    optimizer.zero_grad()
    output = model(torch.randn(1, 10))
    loss = criterion(output, torch.randn(1, 2))
    loss.backward()
    optimizer.step()
    
    # Snapshot the model at epoch 2
    if epoch == 2:
        best_model = model  # Won't work!
        #best_model = copy.deepcopy(model)  # Will work
        
# Compare models
for param1, param2 in zip(best_model.parameters(), model.parameters()):
    print((param1 == param2).all())
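A common alternative to deep-copying the whole module is to deep-copy just its state_dict; this is a minimal sketch of that pattern (the snapshot survives later parameter updates):

```python
import copy

import torch
import torch.nn as nn

model = nn.Linear(10, 2)

# Snapshot the parameters, not the module reference.
# copy.deepcopy ensures later updates don't mutate the snapshot.
best_state = copy.deepcopy(model.state_dict())

# Simulate a further training update.
with torch.no_grad():
    model.weight.add_(1.0)

# The snapshot is unchanged while the live model has moved on.
print(torch.equal(best_state['weight'], model.weight))  # False after the update
```

If you had assigned `best_state = model.state_dict()` without the deep copy, the stored tensors would still alias the live parameters and the comparison would stay True.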

Hi, I would like to know how your code saves the best model, and how the accuracy is compared across epochs?
Thank you very much.

Usually you would calculate the validation error/loss and save the best performing model (i.e. the one with the highest validation accuracy).
Have a look at the ImageNet example to see how save_checkpoint is used to track the best accuracy.
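For reference, here is a minimal sketch in the spirit of the ImageNet example's save_checkpoint: it always stores the latest checkpoint and keeps a separate copy of the best one (the filenames are illustrative):

```python
import shutil

import torch


def save_checkpoint(state, is_best, filename='checkpoint.pt'):
    # Always persist the latest state (e.g. a dict with the epoch,
    # model state_dict, optimizer state_dict, and best accuracy).
    torch.save(state, filename)
    if is_best:
        # Keep a separate copy of the best checkpoint so far.
        shutil.copyfile(filename, 'model_best.pt')
```

During training you would call it each epoch, passing `is_best=cur_accuracy > best_accuracy`; at the end, `model_best.pt` holds the best-performing snapshot.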

OK, thank you very much.

Hi, thanks for the great answer.
I would like to know how the “compare models” code is used: is it for choosing the best trained model, or just for checking whether the two models are identical?
Thanks!

My code snippet was just showing that the original code does not work, as no deep copy was performed.
I would recommend sticking to the linked ImageNet example.

Honestly, this kind of thing should be mentioned in the docs.
The current example of storing the best model in the docs will lead to exactly this kind of bug. And it has already happened: we had been using an overfitted model in prod for months :frowning:

I’m sorry to hear you’ve had this trouble. :confused:
Would you be interested in adding this use case into the docs?

Yep. Created a merge request.

Here’s what works for me:

    model = model.train()
    best_accuracy = 0
    ...
    for epoch in range(100):
        for idx, data in enumerate(data_loader):
            ...
        if cur_accuracy > best_accuracy:
            best_accuracy = cur_accuracy
            torch.save(model.state_dict(), 'best_model.pt')

Thanks for your code snippet. Do you continue training with this best model from here, or just save it for use at the end? What are the downsides of doing the former? @alx

Yes, training continues to the next epoch until it hits a better accuracy. If you use the same filename (here ‘best_model.pt’), each save replaces the prior one. If you add a prefix, e.g. ‘epoch_5_best_model.pt’, you will end up with a list of .pt files at the end of your run.

Personally I prefer keeping them all in case of overfitting.
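The per-epoch-filename idea can be sketched like this; the accuracy values are made up stand-ins for a real validation loop:

```python
import torch
import torch.nn as nn

model = nn.Linear(10, 2)
best_accuracy = 0.0

# Illustrative validation accuracies per epoch; in practice these
# would come from evaluating the model on a validation set.
accuracies = [0.52, 0.61, 0.58, 0.70]

for epoch, cur_accuracy in enumerate(accuracies):
    if cur_accuracy > best_accuracy:
        best_accuracy = cur_accuracy
        # Embedding the epoch in the filename keeps every
        # "best so far" checkpoint instead of overwriting.
        torch.save(model.state_dict(), f'epoch_{epoch}_best_model.pt')
```

With the values above this writes checkpoints for epochs 0, 1, and 3, skipping epoch 2 since its accuracy did not improve on the best so far.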