I’m writing a basic character-level RNN trained on some English text. The model does get very close to zero loss when trained on a very small training set of a few sentences, but it takes many epochs (~500) to do so. When I try it on a more complex dataset (~2-3 MB), the loss fluctuates widely even after training for several hours, and the predictions are extremely repetitive, repeating the same word over and over.
Here’s an example of the loss graph after running for 10 epochs:
And the model itself:
import random

import torch
import torch.nn as nn
import torch.nn.functional as F

# `device` is defined earlier in the script (CPU or CUDA).
class LSTMModel(nn.Module):
    def __init__(self, input_size, hidden_size):
        super(LSTMModel, self).__init__()
        self.lstm = nn.LSTM(input_size, hidden_size, batch_first=True)
        self.h2o = nn.Linear(hidden_size, input_size)  # hidden state -> character logits
        self.input_size = input_size
        self.hidden = None  # persistent hidden/cell state, carried across forward calls

    def forward(self, input):
        # one-hot encode the character indices and add a batch dimension
        input = torch.nn.functional.one_hot(input, num_classes=self.input_size).type(torch.FloatTensor).unsqueeze(0)
        input = input.to(device)
        if self.hidden is None:
            l_output, self.hidden = self.lstm(input)
        else:
            l_output, self.hidden = self.lstm(input, self.hidden)
        # detach so gradients don't flow back into earlier sequences
        self.hidden = (self.hidden[0].detach(), self.hidden[1].detach())
        return self.h2o(l_output)
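For reference, each forward call takes a 1-D tensor of character indices and returns logits of shape (1, seq_len, vocab_size). Here is a minimal smoke-test sketch; the sizes and names below are made up purely for illustration:

# illustrative only: tiny vocabulary of 5 "characters", sequence of 8 indices
toy_model = LSTMModel(input_size=5, hidden_size=16).to(device)
toy_input = torch.randint(0, 5, (8,))   # 1-D LongTensor of character indices
logits = toy_model(toy_input)           # -> shape (1, 8, 5): one logit vector per position
print(logits.shape)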
And the training loop:
# `char_to_in` maps each character to an integer index; `text` is the encoded text (defined earlier).
dict_size = len(char_to_in)
hidden_size = 512
lr = 0.001

model = LSTMModel(dict_size, hidden_size)
model.to(device)
torch.nn.utils.clip_grad_norm_(model.parameters(), 1)

losses = []
optimizer = torch.optim.Adam(model.parameters(), amsgrad=True, lr=lr)

def run_epoch():
    i = 0
    while i < len(text) - 40:
        # take a random-length slice; the target is the same slice shifted by one character
        seq_len = random.randint(10, 40)
        input, target = text[i:i+seq_len], text[i+1:i+1+seq_len]
        i += seq_len
        # forward pass
        output = model(input)
        # loss is computed on the prediction for the last character of the slice only
        loss = F.cross_entropy(output.squeeze()[-1:], torch.Tensor(target[-1:]).long())
        # compute gradients
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        print("Loss:", loss.item())
        losses.append(loss.item())
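For completeness, I drive run_epoch with a simple outer loop, roughly like the sketch below (the epoch count and the matplotlib plotting are just illustrative, not my exact code):

import matplotlib.pyplot as plt  # assumed here only to plot the loss curve

for epoch in range(10):   # 10 epochs for the graph above; the larger dataset ran much longer
    run_epoch()

plt.plot(losses)
plt.xlabel("training step")
plt.ylabel("cross-entropy loss")
plt.show()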
I’m basing this on the following tutorial (with some changes), whose author achieved much better results with a similar structure: An attempt at implementing char-rnn with PyTorch
I’m very much a beginner at building these sorts of models, so apologies if I’m making a basic mistake somewhere, and thanks in advance for any help.