I have just started learning PyTorch and I'm trying to implement a basic seq2seq machine-translation (MT) model, but my loss oscillates instead of decreasing. I tried checking the inputs to the loss function and changing the learning rate, but neither changed the behaviour. The code below is adapted from the PyTorch machine translation tutorial.
import matplotlib.pyplot as plt
import re
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import optim
%matplotlib inline
class Encoder(nn.Module):
    """Bidirectional LSTM encoder that embeds and encodes one step per call.

    Args:
        input_size: Number of features the LSTM expects per step. The
            embedding layer produces vectors of this width.
        hidden_size: Hidden-state size of each LSTM direction.
        source_vocab: Vocabulary object. Only ``vocab`` (used as the number
            of embedding rows), ``pad`` and ``word2ix[pad]`` (the padding
            index) are read.
    """

    def __init__(self, input_size, hidden_size, source_vocab):
        super().__init__()
        self.hidden_size = hidden_size
        pad_idx = source_vocab.word2ix[source_vocab.pad]
        # The embedding width has to match the LSTM's ``input_size``;
        # tying it to ``hidden_size`` only worked when both were equal.
        self.embedding = nn.Embedding(
            num_embeddings=source_vocab.vocab,
            embedding_dim=input_size,
            padding_idx=pad_idx,
        )
        self.encoder = nn.LSTM(
            input_size=input_size,
            hidden_size=hidden_size,
            bidirectional=True,
        )

    def forward(self, input, enc_hidden):
        """Encode ``input`` starting from the state ``enc_hidden``.

        Args:
            input: Tensor of token indices. Assumed to be 1-D with one entry
                per batch element (the training loop passes one token at a
                time) -- TODO confirm against the data pipeline.
            enc_hidden: ``(h, c)`` tuple passed to the LSTM as initial state.

        Returns:
            ``(enc_out, (last_hidden, last_cell))`` exactly as produced by
            the underlying ``nn.LSTM``.
        """
        # Add a leading dimension; nn.LSTM (batch_first=False) reads it as
        # the sequence axis, so each call processes a length-1 sequence.
        embeddings = self.embedding(input).unsqueeze(0)
        enc_out, (last_hidden, last_cell) = self.encoder(embeddings, enc_hidden)
        return enc_out, (last_hidden, last_cell)
class Decoder(nn.Module):
    """Unidirectional LSTM decoder emitting log-probabilities over the target vocab.

    Args:
        input_size: Number of features the LSTM expects per step. The
            embedding layer produces vectors of this width.
        hidden_size: Hidden-state size of the decoder LSTM. The projection
            layers map ``2 * hidden_size`` features down to this size.
        output_size: Number of classes produced by the output layer.
        tgt_vocab: Vocabulary object. Only ``vocab`` (used as the number of
            embedding rows), ``pad`` and ``word2ix[pad]`` (the padding
            index) are read.
    """

    def __init__(self, input_size, hidden_size, output_size, tgt_vocab):
        super().__init__()
        self.hidden_size = hidden_size
        pad_idx = tgt_vocab.word2ix[tgt_vocab.pad]
        # Embedding width must equal the LSTM's ``input_size``.
        self.embedding = nn.Embedding(
            num_embeddings=tgt_vocab.vocab,
            embedding_dim=input_size,
            padding_idx=pad_idx,
        )
        # Project concatenated forward/backward states to the decoder size.
        self.h_proj = nn.Linear(in_features=2 * hidden_size, out_features=hidden_size, bias=False)
        self.c_proj = nn.Linear(in_features=2 * hidden_size, out_features=hidden_size, bias=False)
        self.decoder = nn.LSTM(input_size=input_size, hidden_size=hidden_size)
        # The output layer consumes LSTM outputs, which have ``hidden_size``
        # features (not ``input_size``).
        self.out = nn.Linear(hidden_size, output_size)
        self.softmax = nn.LogSoftmax(dim=1)

    def forward(self, input, init_dec):
        """Run the decoder LSTM on ``input`` and score the first time step.

        Args:
            input: Tensor of token indices; the embeddings are reshaped to
                ``(-1, 1, input_size)``, i.e. a sequence with batch size 1.
            init_dec: ``(h, c)`` tuple. When ``h.size(0) == 2`` the state is
                treated as coming from a bidirectional encoder: both slices
                are concatenated and projected to the decoder's hidden size.
                Otherwise it is used as-is.

        Returns:
            ``(output, dec_hidden)`` where ``output`` holds log-probabilities
            for the first time step of the LSTM output and ``dec_hidden`` is
            the LSTM's final ``(h, c)`` state.
        """
        embeddings = self.embedding(input).view(-1, 1, self.embedding.embedding_dim)

        init_h, init_c = init_dec
        if init_h.size(0) == 2:
            # Two leading slices: merge them so the state fits a
            # single-direction LSTM.
            init_h = self.h_proj(torch.cat([init_h[0], init_h[1]], dim=1)).unsqueeze(0)
            init_c = self.c_proj(torch.cat([init_c[0], init_c[1]], dim=1)).unsqueeze(0)

        dec_out, dec_hidden = self.decoder(embeddings, (init_h, init_c))
        # Only the first time step is scored.
        output = self.softmax(self.out(dec_out[0]))
        return output, dec_hidden
Training:
def train(input_tensor, target_tensor, encoder, decoder, encoder_optimizer,
          decoder_optimizer, criterion, max_src_len, sos_idx=0, eos_idx=2):
    """Run one teacher-forced training step on a single source/target pair.

    Args:
        input_tensor: Source sequence; it is fed to ``encoder`` one leading
            slice at a time.
        target_tensor: Target sequence; each leading slice is used both as
            the loss target and as the next decoder input (teacher forcing).
            Each slice must hold a single element, since ``.item()`` is
            called on it.
        encoder: Callable ``encoder(step, (h, c)) -> (out, (h, c))`` exposing
            a ``hidden_size`` attribute.
        decoder: Callable ``decoder(token, hidden) -> (log_probs, hidden)``.
        encoder_optimizer: Optimizer stepping the encoder parameters.
        decoder_optimizer: Optimizer stepping the decoder parameters.
        criterion: Loss called as ``criterion(log_probs, target_step)``.
        max_src_len: Unused; kept so existing callers keep working.
        sos_idx: Index of the token fed to the decoder at the first step.
        eos_idx: Target index that ends decoding early.

    Returns:
        The summed loss divided by the number of decoding steps actually
        taken, as a Python float.

    Raises:
        ValueError: If ``target_tensor`` has no steps.
    """
    target_length = target_tensor.size(0)
    if target_length == 0:
        raise ValueError("target_tensor must contain at least one token")

    encoder_optimizer.zero_grad()
    decoder_optimizer.zero_grad()

    # Allocate on the input's device instead of relying on a module global.
    device = input_tensor.device
    h0 = torch.zeros(2, 1, encoder.hidden_size, device=device)
    c0 = torch.zeros(2, 1, encoder.hidden_size, device=device)
    enc_hidden = (h0, c0)
    for source_step in input_tensor:
        # Only the running state is needed; per-step outputs are discarded.
        _, enc_hidden = encoder(source_step, enc_hidden)

    decoder_in = torch.tensor([[sos_idx]], device=device)
    dec_hidden = enc_hidden
    loss = 0
    steps = 0
    for target_step in target_tensor:
        dec_out, dec_hidden = decoder(decoder_in, dec_hidden)
        loss += criterion(dec_out, target_step)
        steps += 1
        # Teacher forcing: the ground-truth token is the next input.
        decoder_in = target_step
        if decoder_in.item() == eos_idx:
            break

    loss.backward()
    encoder_optimizer.step()
    decoder_optimizer.step()
    # Normalise by the steps that contributed to the loss, so anything after
    # EOS (e.g. padding) does not deflate the reported value.
    return loss.item() / steps
def trainIters(encoder, decoder, n_iters, lr=0.0001):
    """Train on the module-level ``pairs`` for ``n_iters`` single-pair steps.

    A fresh SGD optimizer is created for each model and ``train`` is called
    once per iteration with an ``nn.NLLLoss`` criterion. Pairs are visited in
    order and the list is cycled when ``n_iters`` exceeds its length.

    Args:
        encoder: Encoder module; its parameters are optimised with SGD.
        decoder: Decoder module; its parameters are optimised with SGD.
        n_iters: Number of training steps to run.
        lr: Learning rate for both optimizers.

    Returns:
        A list with the value returned by ``train`` for every step. Each
        entry is the loss of a single pair, so consecutive values fluctuate
        from example to example; average over a window to see the trend.

    Raises:
        ValueError: If ``n_iters`` is positive but ``pairs`` is empty.
    """
    encoder_optimizer = optim.SGD(encoder.parameters(), lr=lr)
    decoder_optimizer = optim.SGD(decoder.parameters(), lr=lr)
    # ``pairs`` and ``max_len_src`` are read from module scope; each pair is
    # presumably (input_tensor, target_tensor) -- verify against the loader.
    train_pairs = pairs
    n_pairs = len(train_pairs)
    if n_iters > 0 and n_pairs == 0:
        raise ValueError("no training pairs available")

    criterion = nn.NLLLoss()
    loss_stack = []
    for i in range(n_iters):
        # Wrap around instead of raising IndexError past the end of the data.
        training_pair = train_pairs[i % n_pairs]
        ip_ten = training_pair[0]
        tgt_ten = training_pair[1]
        loss = train(ip_ten, tgt_ten, encoder, decoder, encoder_optimizer,
                     decoder_optimizer, criterion, max_len_src)
        loss_stack.append(loss)
    return loss_stack