Optimizer issues with growing architecture

Hi, I am trying to build a self-attention module that uses a lookup table of song embeddings to generate playlist embeddings. Training runs smoothly, but when I try to load a checkpoint, the optimizer raises:

ValueError: loaded state dict has a different number of parameter groups.
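
As far as I can tell, this error is raised by Optimizer.load_state_dict when the number of parameter groups in the saved state does not match the optimizer it is loaded into. This toy sketch (unrelated to my model, just two throwaway linear layers) reproduces the same ValueError:

import torch
import torch.nn as nn

layer_a, layer_b = nn.Linear(4, 4), nn.Linear(4, 4)

# optimizer state saved with two parameter groups...
opt_saved = torch.optim.Adam(
    [{'params': layer_a.parameters()}, {'params': layer_b.parameters()}], lr=1e-4)

# ...loaded into an optimizer built with a single parameter group
opt_new = torch.optim.Adam(
    list(layer_a.parameters()) + list(layer_b.parameters()), lr=1e-4)

opt_new.load_state_dict(opt_saved.state_dict())  # raises the ValueError shown above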

Here is a copy of my model; I wrote it based on this previous post.

class Playlist_Constructor(nn.Module):
    def __init__(self, conf):
        super().__init__()
        self.method = conf["method"]
        self.num_layers = 1
        self.transformer_encoder = [TransformerEncoderLayer(conf).to(self.device) for _ in range(self.num_layers)]
        
    def forward(self, track_idx_seq, track_feat_seq, mask, length):
        # track_idx_seq: [bs, N_max]
        # track_feat_seq: [bs, N_max, dim]
        # mask: [bs, N_max]
        # length: [bs]
    
        feat = track_feat_seq
        for i in range(self.num_layers):
            feat = self.transformer_encoder[i](feat, mask)
        feat = feat.sum(-2) / length.unsqueeze(-1)
        return feat

I save the checkpoint using the following code:

def save_checkpoint(model, optimizer, config, epoch, checkpoint_path): 
    save_obj = {
        'model': model.state_dict(),
        'optimizer': optimizer.state_dict(),
        'config': config, 
        'epoch': epoch 
    }
    if config['cmdline']['fusion_method'] == "self_attn" and config['cmdline']['mode'] == "CIP":
        if hasattr(model, 'module'):
            save_obj['playlist_encoder'] = [i.state_dict() for i in model.module.playlist_constructor.transformer_encoder]
        else:
            save_obj['playlist_encoder'] = [i.state_dict() for i in model.playlist_constructor.transformer_encoder]

    torch.save(save_obj, os.path.join(checkpoint_path, 'checkpoint_%02d.pth' % epoch))
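
For reference, here is a quick way to sanity-check what actually ends up in the saved file (just a sketch; the file name uses the same naming scheme as above, with epoch 0 as an example):

ckpt = torch.load(os.path.join(checkpoint_path, 'checkpoint_00.pth'), map_location='cpu')

print(ckpt.keys())                             # model / optimizer / config / epoch (+ playlist_encoder)
print(len(ckpt['optimizer']['param_groups']))  # how many parameter groups were saved with the optimizer
if 'playlist_encoder' in ckpt:
    print(len(ckpt['playlist_encoder']))       # one state dict per encoder layer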

And here is the code I use to load the checkpoint and restore the optimizer:

checkpoint = torch.load(args.checkpoint, map_location='cpu')
state_dict = checkpoint['model']    
model.load_state_dict(state_dict, strict=False)
for idx, i in enumerate(model.playlist_constructor.transformer_encoder):
    i.load_state_dict(checkpoint['playlist_encoder'][idx], strict=False)
optimizer_state_dict = checkpoint['optimizer']
optimizer = torch.optim.Adam(params=model.parameters(), lr=1e-4, betas=(0.9, 0.999), eps=1e-08)
optimizer.load_state_dict(optimizer_state_dict)

My suspicion is that the optimizer is not storing the parameters correctly, but I am not sure.
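
To narrow it down, this is the comparison I can run right before optimizer.load_state_dict (a minimal sketch, reusing the variables from the snippet above):

# parameter groups in the freshly built optimizer vs. in the saved optimizer state
print(len(optimizer.state_dict()['param_groups']))
print(len(optimizer_state_dict['param_groups']))

# how many tensors the new optimizer is actually tracking
print(sum(len(g['params']) for g in optimizer.state_dict()['param_groups']))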

Thank you in advance for your help!

Could you show the part where your model is growing?
Also, why do you use strict=False? Do you expect to see unexpected or missing keys?
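
If it helps, load_state_dict returns the mismatches explicitly, so you can print them instead of silencing them (a quick sketch using your variable names):

result = model.load_state_dict(state_dict, strict=False)
print('missing keys:', result.missing_keys)        # in the model but not in the checkpoint
print('unexpected keys:', result.unexpected_keys)  # in the checkpoint but not in the model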

Could you show the part where your model is growing?
It seems that the optimizer assumes the model is growing, but I am not trying to grow it.

Also, why do you use strict=False? Do you expect to see unexpected or missing keys?
I thought it might resolve the issue, but it did not; the error persists even when I remove it.
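
To double-check that nothing is being added dynamically, I am going to print the parameter names the optimizer receives and see whether the encoder layers show up (just a sketch):

for name, p in model.named_parameters():
    print(name, tuple(p.shape))
# checking whether any 'playlist_constructor.transformer_encoder' entries appear here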