LSTM text generation: loss not decreasing

Hi all,
I just shifted from Keras and am having some difficulty validating my code. I am currently training an LSTM network for character-level text generation, but I observe that my loss is not decreasing. I am now doubting whether my model is built incorrectly. Could someone kindly help me with this?

Here's the code:

```python
import torch
import torch.nn.functional as F
import torch.utils.data as data_utils
from torch.autograd import Variable


class CharLevelLanguageModel(torch.nn.Module):
    def __init__(self, vocab_size, emb_dim, hidden_dim, batch_size):
        super(CharLevelLanguageModel, self).__init__()
        self.embeddings = torch.nn.Embedding(vocab_size, emb_dim, padding_idx=0)
        self.lstm = torch.nn.LSTM(emb_dim, hidden_dim, 1, batch_first=True)
        self.linear = torch.nn.Linear(hidden_dim, vocab_size)
        self.batch_size = batch_size
        self.hidden_dim = hidden_dim
        self.hidden_state = self.init_hidden()

    def init_hidden(self):
        return Variable(torch.zeros(1, self.batch_size, self.hidden_dim))

    def forward(self, x):
        embeds = self.embeddings(x)
        output, self.hidden_state = self.lstm(embeds)
        # predict the next character from the last time step's output
        return F.log_softmax(self.linear(F.tanh(output[:, -1, :])))


train = data_utils.TensorDataset(torch.LongTensor(dataX), torch.LongTensor(dataY))
train_loader = data_utils.DataLoader(train, batch_size=100, drop_last=True)
model = CharLevelLanguageModel(len(char_to_id) + 1, 100, 50, 100)
criterion = torch.nn.NLLLoss()
optimizer = torch.optim.Adam(model.parameters())

losses = []
for i in range(3):
    total_loss = torch.FloatTensor([0])
    for batch_idx, train in enumerate(train_loader):
        model.init_hidden()
        x, y = Variable(train[0]), Variable(train[1])
        y_pred = model(x)
        loss = criterion(y_pred, y)
        total_loss += loss.data
        loss.backward()
        optimizer.zero_grad()
        optimizer.step()
    losses.append(total_loss)
```


The losses are:

```
[
 6798.0005
 [torch.FloatTensor of size 1],
 6798.0005
 [torch.FloatTensor of size 1],
 6798.0005
 [torch.FloatTensor of size 1]]
```

Zero the grad before the backward, not after.
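
Calling `optimizer.zero_grad()` between `loss.backward()` and `optimizer.step()` throws away the gradients you just computed, so `step()` applies an all-zero update and the loss stays constant. A minimal sketch of the corrected inner loop (variable names taken from your snippet, with the loop variable renamed to `batch` so it no longer shadows the `train` dataset):

```python
for batch_idx, batch in enumerate(train_loader):
    optimizer.zero_grad()            # clear leftover gradients first
    x, y = Variable(batch[0]), Variable(batch[1])
    y_pred = model(x)
    loss = criterion(y_pred, y)
    total_loss += loss.data
    loss.backward()                  # compute gradients for this batch
    optimizer.step()                 # update parameters with those gradients
```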

Best regards

Thomas


Great, it looks like the loss is decreasing now. Thanks a lot.
Other than that, is the model built correctly?