Hi, I am trying to build a self-attention module that uses a lookup table of song embeddings to generate playlist representations. Training the model goes smoothly, but when I try to load a checkpoint the optimizer raises the following error:
ValueError: loaded state dict has a different number of parameter groups
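From looking at the PyTorch source, this ValueError comes from Optimizer.load_state_dict, which compares the number of parameter groups in the saved state against the optimizer it is being loaded into. A minimal sketch with a toy model (not my actual code) that reproduces the same message:

import torch
import torch.nn as nn

# Toy example only: the saved optimizer was built with two parameter groups,
# while the freshly created one has a single group.
net = nn.Linear(4, 2)
opt_two_groups = torch.optim.Adam([
    {'params': [net.weight]},
    {'params': [net.bias], 'lr': 1e-3},
])
saved_state = opt_two_groups.state_dict()

opt_one_group = torch.optim.Adam(net.parameters())  # a single parameter group
try:
    opt_one_group.load_state_dict(saved_state)
except ValueError as e:
    print(e)  # loaded state dict has a different number of parameter groups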
Here is a copy of my model, which I wrote based on this previous post:
class Playlist_Constructor(nn.Module):
    def __init__(self, conf):
        super().__init__()
        self.method = conf["method"]
        self.num_layers = 1
        self.transformer_encoder = [TransformerEncoderLayer(conf).to(self.device) for _ in range(self.num_layers)]

    def forward(self, track_idx_seq, track_feat_seq, mask, length):
        # track_idx_seq: [bs, N_max]
        # track_feat_seq: [bs, N_max, dim]
        # mask: [bs, N_max]
        # length: [bs]
        feat = track_feat_seq
        for i in range(self.num_layers):
            feat = self.transformer_encoder[i](feat, mask)
        feat = feat.sum(-2) / length.unsqueeze(-1)
        return feat
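I am not sure whether this is related, but while debugging I noticed that modules kept in a plain Python list (as opposed to an nn.ModuleList) are not registered as submodules, so their parameters do not appear in .parameters() or .state_dict(). A minimal sketch with toy layers (not my real encoder) showing the difference:

import torch.nn as nn

class WithPlainList(nn.Module):
    def __init__(self):
        super().__init__()
        self.layers = [nn.Linear(8, 8) for _ in range(2)]  # plain Python list

class WithModuleList(nn.Module):
    def __init__(self):
        super().__init__()
        self.layers = nn.ModuleList(nn.Linear(8, 8) for _ in range(2))

print(len(list(WithPlainList().parameters())))   # 0  -> nothing registered
print(len(list(WithModuleList().parameters())))  # 4  -> weight + bias for each layer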
I save the checkpoint using the following code:
def save_checkpoint(model, optimizer, config, epoch, checkpoint_path):
    save_obj = {
        'model': model.state_dict(),
        'optimizer': optimizer.state_dict(),
        'config': config,
        'epoch': epoch
    }
    if config['cmdline']['fusion_method'] == "self_attn" and config['cmdline']['mode'] == "CIP":
        if hasattr(model, 'module'):
            save_obj['playlist_encoder'] = [i.state_dict() for i in model.module.playlist_constructor.transformer_encoder]
        else:
            save_obj['playlist_encoder'] = [i.state_dict() for i in model.playlist_constructor.transformer_encoder]
    torch.save(save_obj, os.path.join(checkpoint_path, 'checkpoint_%02d.pth' % epoch))
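For what it's worth, this is roughly how I inspect what actually ends up in a saved file (the path below is just an example, not my real checkpoint directory):

ckpt = torch.load('checkpoint_01.pth', map_location='cpu')
print(list(ckpt.keys()))            # expecting: model, optimizer, config, epoch (+ playlist_encoder)
print('playlist_encoder' in ckpt)   # only saved for the self_attn + CIP configuration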
And here is the code I use to load the checkpoint and restore the optimizer:
checkpoint = torch.load(args.checkpoint, map_location='cpu')
state_dict = checkpoint['model']
model.load_state_dict(state_dict, strict=False)
for idx, i in enumerate(model.playlist_constructor.transformer_encoder):
    i.load_state_dict(checkpoint['playlist_encoder'][idx], strict=False)
optimizer_state_dict = checkpoint['optimizer']
optimizer = torch.optim.Adam(params=model.parameters(), lr=float(1e-4), betas=(0.9,0.999), eps=1e-08)
optimizer.load_state_dict(optimizer_state_dict)
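To narrow things down, I was planning to compare the number of parameter groups in the freshly created optimizer against the ones stored in the checkpoint, along the lines of this sketch (reusing the variables from the snippet above):

# Both counts need to match for load_state_dict to succeed.
print(len(optimizer.state_dict()['param_groups']))   # groups in the new Adam (here: 1)
print(len(optimizer_state_dict['param_groups']))     # groups saved in the checkpoint
print(sum(len(g['params']) for g in optimizer_state_dict['param_groups']))  # total parameters the saved optimizer tracked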
My suspicion is that the optimizer is not keeping track of all the model's parameters correctly, but I am not sure.
Thank you in advance for your help!