Hello,
I am experimenting with a simple encoder-decoder model and with save/load functions.
I’ve managed to run the model, yet I have an issue with reloading it. After restarting the IDE’s kernel (Spyder) and loading the model state, the loss is even higher than in a fresh (newly initialized) model.
I’m using the Adam optimizer, and yes - I save its state just like the rest of the model parameters.
With SGD, no such issue occurs.
Models:
class Encoder(nn.Module):
    """Single-step GRU encoder: embeds one token index and advances the GRU.

    Args:
        voc_dim: Number of entries in the embedding table (presumably the
            vocabulary size).
        hidden_dim: Size of both the embedding vectors and the GRU state.
    """

    def __init__(self, voc_dim, hidden_dim):
        super(Encoder, self).__init__()
        self.hidden_dim = hidden_dim
        self.embeddings = nn.Embedding(voc_dim, hidden_dim)
        self.gru = nn.GRU(hidden_dim, hidden_dim)

    def iniHidden(self):
        """Return an all-zero initial GRU state of shape (1, 1, hidden_dim).

        The state is allocated with the dtype and device of the embedding
        weights, so it remains compatible with the GRU after the module has
        been converted with ``.to(...)``, ``.cuda()`` or ``.double()``.
        """
        weight = self.embeddings.weight
        return torch.zeros(1, 1, self.hidden_dim,
                           dtype=weight.dtype, device=weight.device)

    def forward(self, inputs, hidden):
        """Run one GRU step.

        Args:
            inputs: Index tensor whose embedding is reshaped to (1, 1, -1),
                i.e. one token per call.
            hidden: Previous GRU state, e.g. the result of ``iniHidden()``.

        Returns:
            Tuple ``(out, hidden)`` as produced by the GRU.
        """
        # The GRU is fed a sequence of length 1 with batch size 1.
        embeds = self.embeddings(inputs).view(1, 1, -1)
        out, hidden = self.gru(embeds, hidden)
        return out, hidden
class Decoder(nn.Module):
    """Single-step GRU decoder producing log-probabilities over ``voc_dim`` classes.

    Args:
        voc_dim: Number of entries in the embedding table and number of
            output classes (presumably the vocabulary size).
        hidden_dim: Size of both the embedding vectors and the GRU state.
    """

    def __init__(self, voc_dim, hidden_dim):
        super(Decoder, self).__init__()
        self.hidden_dim = hidden_dim
        self.embeddings = nn.Embedding(voc_dim, hidden_dim)
        self.gru = nn.GRU(hidden_dim, hidden_dim)
        self.out = nn.Linear(hidden_dim, voc_dim)
        self.softmax = nn.LogSoftmax(dim=1)

    def iniHidden(self):
        """Return an all-zero initial GRU state of shape (1, 1, hidden_dim).

        The state is allocated with the dtype and device of the embedding
        weights, so it remains compatible with the GRU after the module has
        been converted with ``.to(...)``, ``.cuda()`` or ``.double()``.
        """
        weight = self.embeddings.weight
        return torch.zeros(1, 1, self.hidden_dim,
                           dtype=weight.dtype, device=weight.device)

    def forward(self, input, hidden):
        """Run one decoding step.

        Args:
            input: Index tensor whose embedding is reshaped to (1, 1, -1),
                i.e. one token per call.
            hidden: Previous GRU state.

        Returns:
            Tuple ``(output, hidden)`` where ``output`` holds the
            log-softmax scores of shape (1, voc_dim) and ``hidden`` is the
            updated GRU state.
        """
        output = self.embeddings(input).view(1, 1, -1)
        output = F.relu(output)
        output, hidden = self.gru(output, hidden)
        # output[0] drops the length-1 sequence axis, leaving (1, hidden_dim),
        # so LogSoftmax(dim=1) normalises across the voc_dim scores.
        output = self.softmax(self.out(output[0]))
        return output, hidden
Save/Load functions:
def SaveModel(epo, enco, deco, enco_opti, deco_opti, losses, PATH):
    """Write a training checkpoint to ``chk_<PATH>.pt`` in the working directory.

    Args:
        epo: Epoch value stored under ``'Epoch'``.
        enco: Encoder module; its ``state_dict()`` is stored.
        deco: Decoder module; its ``state_dict()`` is stored.
        enco_opti: Encoder optimizer; its ``state_dict()`` is stored.
        deco_opti: Decoder optimizer; its ``state_dict()`` is stored.
        losses: Loss history stored as-is under ``'Loss Curve'``.
        PATH: String tag inserted between ``chk_`` and ``.pt`` to form the
            file name.
    """
    target = 'chk_' + PATH + '.pt'

    snapshot = {}
    snapshot['Epoch'] = epo
    snapshot['Encoder state'] = enco.state_dict()
    snapshot['Decoder state'] = deco.state_dict()
    snapshot['Encoder Optimizer'] = enco_opti.state_dict()
    snapshot['Decoder Optimizer'] = deco_opti.state_dict()
    snapshot['Loss Curve'] = losses

    torch.save(snapshot, target)
    print('Check point saved to ' + target)
def LoadModel(enco, deco, enco_opti, deco_opti, PATH, test = False, map_location = None):
    """Restore models and optimizers in place from ``chk_<PATH>.pt``.

    Args:
        enco: Encoder module that receives the ``'Encoder state'`` entry.
        deco: Decoder module that receives the ``'Decoder state'`` entry.
        enco_opti: Optimizer that receives the ``'Encoder Optimizer'`` entry.
        deco_opti: Optimizer that receives the ``'Decoder Optimizer'`` entry.
        PATH: String tag inserted between ``chk_`` and ``.pt`` to form the
            file name.
        test: When true, both modules are switched to evaluation mode after
            loading. When false, their current mode is left untouched.
        map_location: Forwarded to ``torch.load``. Pass e.g. ``'cpu'`` to
            open a checkpoint whose tensors were saved on a device that is
            not available in the current process. The default ``None``
            keeps the previous behaviour.

    Returns:
        Tuple ``(epo, enco, deco, enco_opti, deco_opti, losses)``; the four
        objects are the ones passed in, and ``epo`` / ``losses`` are the
        checkpoint's ``'Epoch'`` and ``'Loss Curve'`` entries.
    """
    # NOTE: torch.load is pickle-based unless weights-only loading is in
    # effect, so only open checkpoint files from a trusted source.
    chkpoint = torch.load('chk_' + PATH + '.pt', map_location=map_location)
    print('Loading chk_' + PATH)

    epo = chkpoint['Epoch']
    enco.load_state_dict(chkpoint['Encoder state'])
    deco.load_state_dict(chkpoint['Decoder state'])
    enco_opti.load_state_dict(chkpoint['Encoder Optimizer'])
    deco_opti.load_state_dict(chkpoint['Decoder Optimizer'])
    losses = chkpoint['Loss Curve']

    if test:
        enco.eval()
        deco.eval()

    return epo, enco, deco, enco_opti, deco_opti, losses
I’ve checked that all the parameters are loaded correctly.
Has anyone had a similar issue?
Thanks,
Serge