Optimizer.load_state_dict() makes the loss worse when I resume training

I am training my Segformer model with the AdamW optimizer. I trained the model for 100 epochs and then saved both the model and the optimizer state dicts. At the time of saving, the validation loss was 1.2.

Now, when I resume training, the validation loss starts at 30.2.

```python
import torch
import torch.nn as nn
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)  # `model` is the Segformer instance defined elsewhere
optimizer = torch.optim.AdamW(model.parameters(), lr=5e-3, weight_decay=0.01)
criterion = nn.BCEWithLogitsLoss()
model.load_state_dict(torch.load("segformer_full/epoch-100.pt", map_location=device))
optimizer.load_state_dict(torch.load("segformer_optimizer/epoch-100.pt", map_location=device))
```
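A common way to make resuming less error-prone is to keep the model state, the optimizer state, the epoch counter and the best loss together in one checkpoint file. The sketch below assumes that layout; the helper names `save_checkpoint`/`load_checkpoint` and the dictionary keys are illustrative choices, not part of the original code.

```python
import torch


def save_checkpoint(path, model, optimizer, epoch, best_loss):
    # Bundle everything needed to resume training into a single file.
    torch.save(
        {
            "epoch": epoch,
            "best_loss": best_loss,
            "model_state_dict": model.state_dict(),
            "optimizer_state_dict": optimizer.state_dict(),
        },
        path,
    )


def load_checkpoint(path, model, optimizer, device):
    # map_location lets a checkpoint saved on GPU be loaded on the current device.
    checkpoint = torch.load(path, map_location=device)
    model.load_state_dict(checkpoint["model_state_dict"])
    optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
    return checkpoint["epoch"], checkpoint["best_loss"]
```

When resuming, starting the epoch loop at the returned epoch + 1 keeps the `epoch-{}.pt` filenames from restarting at 0, and restoring `best_loss` means the first resumed epoch is compared against the real best value rather than whatever `best_loss` is reinitialized to.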

Here are the relevant parts of the training loop:

```python
import os

import torch
from tqdm import tqdm

for epoch in range(num_epochs):
    model.train()
    train_loss = 0
    for i in tqdm(range(len(datagen))):
        ...  # training step (elided)

    model.eval()
    valid_loss = 0
    with torch.no_grad():
        for i in tqdm(range(len(datagen) - 73)):
            ...  # validation step (elided)

    if valid_loss < best_loss:
        best_loss = valid_loss
        torch.save(model.state_dict(), os.path.join("segformer_full", f"epoch-{epoch}.pt"))
        torch.save(optimizer.state_dict(), os.path.join("segformer_optimizer", f"epoch-{epoch}.pt"))
```

Can anyone tell me where I am going wrong?

Are you seeing the worse loss value as soon as you load the model's state_dict and run the validation loop, or only after performing a training step with the restored optimizer?

I would like to continue training my model.

My suggestion wasn’t to change your workflow, but to use it as a debugging step to check if loading the state_dict of the model already causes the issue while the optimizer might be fine.

If I use optimizer.load_state_dict() alone, the validation losses for the first 5 epochs are:
45.09
34.47
30.77
37.57
23.28

If I use both model.load_state_dict() and optimizer.load_state_dict(), the validation losses for the first 5 epochs are:
28.53
39.14
28.32
26.47
21.64

Could you check by loading only the model's state_dict and running the validation phase, as mentioned before?

Using model.load_state_dict() alone, it begins at 32.47, and the following epochs give:
14.4
15.94
7.11
10.34

So loading the model's state_dict alone already fails, since the last validation loss was:

At the time of saving the validation loss was 1.2

In that case, I would recommend checking the model's output for a constant input (e.g. torch.ones) before saving and again after restoring, to make sure the parameters are properly loaded. If this check already fails, compare the values of the individual layers to find the ones that are not restored properly.
However, if the static input yields the same results (up to floating-point precision errors), you should check the data processing instead.
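The check suggested above could look roughly like this. It is a sketch under assumptions: `output_on_constant_input` and `mismatched_entries` are hypothetical helpers, the input shape must match what the model expects (for a Segformer, something like `(1, 3, H, W)`), and the model's forward is assumed to return a tensor (a Hugging Face model returns an output object, in which case you would compare e.g. its `.logits`).

```python
import torch


@torch.no_grad()
def output_on_constant_input(model, input_shape):
    # eval() disables dropout and makes batch norm use its running stats,
    # so the same weights always give the same output for the same input.
    model.eval()
    device = next(model.parameters()).device
    return model(torch.ones(input_shape, device=device))


def mismatched_entries(state_a, state_b, atol=1e-6):
    # Names of state_dict entries (parameters and buffers) that are missing
    # from one side, have a different shape, or hold different values.
    bad = sorted(set(state_b) - set(state_a))
    for name, tensor in state_a.items():
        other = state_b.get(name)
        if other is None or other.shape != tensor.shape:
            bad.append(name)
        elif not torch.allclose(tensor.float(), other.float().to(tensor.device), atol=atol):
            bad.append(name)
    return bad
```

In practice you would record the output right before calling `torch.save`, compare it with `torch.allclose` against the output of the restored model, and only if they differ use `mismatched_entries(saved_state, model.state_dict())` to list the layers whose values were not restored.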

Yeah, it works now. I'm not sure whether the issue was the varying validation set or the state_dict not being saved properly.

Maybe the issue was that my validation set kept changing (i.e., every epoch it generated a new set).
I changed the validation set to a fixed number of samples, tried again, and it worked properly.
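A validation set that is regenerated every epoch makes losses from different epochs, and from before and after a restart, hard to compare. One way to freeze it, sketched here under the assumption that the sample generator can take a `torch.Generator` for all of its randomness, is to draw the samples once from a seeded generator and reuse them. `FixedValidationSet` and `random_sample` are hypothetical stand-ins for the original `datagen`.

```python
import torch
from torch.utils.data import DataLoader, Dataset


def random_sample(gen):
    # Hypothetical toy generator: a random image and a binary mask,
    # drawing all randomness from the supplied generator.
    image = torch.rand(3, 32, 32, generator=gen)
    mask = (torch.rand(1, 32, 32, generator=gen) > 0.5).float()
    return image, mask


class FixedValidationSet(Dataset):
    """Draws num_samples once from a seeded generator and keeps them."""

    def __init__(self, sample_fn, num_samples, seed=0):
        # A private generator leaves the global RNG (used for training) untouched.
        gen = torch.Generator().manual_seed(seed)
        self.samples = [sample_fn(gen) for _ in range(num_samples)]

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        return self.samples[idx]


# Build once, outside the epoch loop, and iterate in a fixed order.
val_set = FixedValidationSet(random_sample, num_samples=10, seed=0)
val_loader = DataLoader(val_set, batch_size=4, shuffle=False)
```

Because the samples come from a fixed seed, the same set is rebuilt after restarting the script. Holding every sample in memory is fine for a small validation set; for a large one you would store sample indices or file paths instead.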

Thank you