Adam optimizer loading issues

Hello,

I'm experimenting with a simple encoder-decoder model and save/load functions.

I've managed to train the model, but I have an issue with reloading it. After restarting the IDE's kernel (Spyder) and loading the model state, the loss is even higher than in a freshly initialized model.

I'm using the Adam optimizer, and yes, I save its state just like the rest of the model parameters.
With SGD there is no such issue.
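
For context, Adam keeps per-parameter running moments (exp_avg and exp_avg_sq) in its state dict, so unlike plain SGD there is real optimizer state to restore. A minimal round-trip check on a toy model (just a sketch, not my actual code) looks like this:

import torch
import torch.nn as nn

model = nn.Linear(4, 2)  # placeholder model
opt = torch.optim.Adam(model.parameters(), lr=1e-3)

# Take one step so Adam actually has state to save
model(torch.randn(8, 4)).sum().backward()
opt.step()

torch.save(opt.state_dict(), 'opt.pt')

# Rebuild the optimizer and reload its state
opt2 = torch.optim.Adam(model.parameters(), lr=1e-3)
opt2.load_state_dict(torch.load('opt.pt'))

# The running moments should match exactly after loading
for s1, s2 in zip(opt.state_dict()['state'].values(),
                  opt2.state_dict()['state'].values()):
    assert torch.equal(s1['exp_avg'], s2['exp_avg'])
    assert torch.equal(s1['exp_avg_sq'], s2['exp_avg_sq'])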

Models:

import torch
import torch.nn as nn
import torch.nn.functional as F

class Encoder(nn.Module):

    def __init__(self, voc_dim, hidden_dim):
        super(Encoder, self).__init__()
        self.hidden_dim = hidden_dim
        self.embeddings = nn.Embedding(voc_dim, hidden_dim)
        self.gru = nn.GRU(hidden_dim, hidden_dim)

    def iniHidden(self):
        return torch.zeros(1, 1, self.hidden_dim)

    def forward(self, inputs, hidden):
        embeds = self.embeddings(inputs).view(1, 1, -1)
        out, hidden = self.gru(embeds, hidden)
        return out, hidden


class Decoder(nn.Module):

    def __init__(self, voc_dim, hidden_dim):
        super(Decoder, self).__init__()
        self.hidden_dim = hidden_dim
        self.embeddings = nn.Embedding(voc_dim, hidden_dim)
        self.gru = nn.GRU(hidden_dim, hidden_dim)
        self.out = nn.Linear(hidden_dim, voc_dim)
        self.softmax = nn.LogSoftmax(dim=1)

    def iniHidden(self):
        return torch.zeros(1, 1, self.hidden_dim)

    def forward(self, input, hidden):
        output = self.embeddings(input).view(1, 1, -1)
        output = F.relu(output)
        output, hidden = self.gru(output, hidden)
        output = self.softmax(self.out(output[0]))
        return output, hidden

Save/load functions:

def SaveModel(epo, enco, deco, enco_opti, deco_opti, losses, PATH):
    chkpoint = {'Epoch': epo,
                'Encoder state': enco.state_dict(),
                'Decoder state': deco.state_dict(),
                'Encoder Optimizer': enco_opti.state_dict(),
                'Decoder Optimizer': deco_opti.state_dict(),
                'Loss Curve': losses}
    torch.save(chkpoint, 'chk_' + PATH + '.pt')
    print('Checkpoint saved to chk_' + PATH + '.pt')


def LoadModel(enco, deco, enco_opti, deco_opti, PATH, test=False):
    chkpoint = torch.load('chk_' + PATH + '.pt')
    print('Loading chk_' + PATH)
    epo = chkpoint['Epoch']
    enco.load_state_dict(chkpoint['Encoder state'])
    deco.load_state_dict(chkpoint['Decoder state'])
    enco_opti.load_state_dict(chkpoint['Encoder Optimizer'])
    deco_opti.load_state_dict(chkpoint['Decoder Optimizer'])
    losses = chkpoint['Loss Curve']
    if test:
        # Switch to eval mode only for inference, not when resuming training
        enco.eval()
        deco.eval()
    return epo, enco, deco, enco_opti, deco_opti, losses
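
For clarity, this is roughly how I call them (the dims and learning rate here are placeholders, not my real settings):

enc = Encoder(voc_dim=100, hidden_dim=32)
dec = Decoder(voc_dim=100, hidden_dim=32)
enc_optimizer = torch.optim.Adam(enc.parameters(), lr=1e-3)
dec_optimizer = torch.optim.Adam(dec.parameters(), lr=1e-3)
losses = []

# ... train for some epochs, appending to losses ...
SaveModel(5, enc, dec, enc_optimizer, dec_optimizer, losses, 'run1')

# After restarting the kernel: rebuild the same objects first, then load into them
epoch, enc, dec, enc_optimizer, dec_optimizer, losses = LoadModel(
    enc, dec, enc_optimizer, dec_optimizer, 'run1')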

I’ve checked that all the parameters are loaded correctly.
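
The check was roughly like this (a sketch; pre_save_state stands in for a copy of the encoder state dict I kept from before the restart):

def StatesMatch(sd_a, sd_b):
    # Model state dicts map parameter names to tensors, so compare tensor by tensor
    return all(torch.equal(sd_a[k], sd_b[k]) for k in sd_a)

print(StatesMatch(pre_save_state, enc.state_dict()))  # prints True for me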

Has anyone had a similar issue?

Thanks,
Serge

The code looks generally fine.
When are you calling init_hidden in your training procedure?
Could this be an issue?

Thanks for the quick reply.
I thought it was an issue too. At first, init_hidden was called once per training run and saved with the rest of the model. Later, I changed it to run every epoch, just like in the seq2seq tutorial.
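
Concretely, the change was along these lines (just a sketch):

# Before: hidden state created once for the whole training run
enc_hidden = enc.iniHidden()
for epoch in range(ep):
    ...

# After: hidden state re-created at the top of every epoch, as in the routine below
for epoch in range(ep):
    enc_hidden = enc.iniHidden()
    ...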

This is my training routine:

print('Starting loop')
losses = []

for epoch in range(ep):
    ep_start = time.time()
    total_loss = 0
    enc_optimizer.zero_grad()
    dec_optimizer.zero_grad()
    enc_hidden = enc.iniHidden()

    for context, target in loader:
        # Feed the context tokens through the encoder one by one
        for con in context[0]:
            enc_out, enc_hidden = enc(con, enc_hidden)

        dec_hidden = enc_hidden
        dec_input = torch.tensor(9)  # start-of-sequence token index

        for iii in range(len(target)):
            dec_out, dec_hidden = dec(dec_input, dec_hidden)
            top_val, top_idx = dec_out.topk(1)
            dec_input = top_idx.squeeze().detach()  # detach: remove from gradient calculation
            total_loss += criterion(dec_out, target[iii].view(-1))

    # One backward pass and one optimizer step per epoch, over the accumulated loss
    total_loss.backward()
    losses.append(total_loss.item() / len(Y))

    enc_optimizer.step()
    dec_optimizer.step()

    dec_hidden = dec_hidden.detach()
    enc_hidden = enc_hidden.detach()

    ep_end = time.time()
    ep_del = ep_end - ep_start
    (mi, sec) = divmod(ep_del * (ep - epoch - 1), 60)
    if epoch == 0:
        print('Estimated time is {} minutes.'.format(int(mi)))

    print('Epoch', epoch + 1, '/', ep, ', Loss :', round(total_loss.item() / len(Y), 4))
    if (epoch + 1) % 5 == 0:
        SaveModel(epoch + 1, enc, dec, enc_optimizer, dec_optimizer, losses,
                  str(time.localtime(time.time()).tm_hour) + '_'
                  + str(time.localtime(time.time()).tm_min) + '_'
                  + str(time.localtime(time.time()).tm_sec))

Thank you!