Hi everyone,
I’ve been looking at previous posts about similar issues with implementing these masks, but things are still not clear to me for my use case.
I am training a language model with a transformer encoder-decoder. I know an encoder alone is enough for an LM, but I want to try this setup anyway and understand it.
My problem is that after 4 or 5 epochs the perplexity drops very low, but at inference the model only repeats the last token of the sentence it is given.
My model takes its input from the get_batch function of the seq2seq tutorial:
def get_batch(source, i):
    seq_len = min(bptt, len(source) - 1 - i)
    data = source[i:i+seq_len]
    target = source[i+1:i+1+seq_len]
    return data, target
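To make the offset concrete, here is a minimal pure-Python sketch of the slicing logic above, using a toy token list of my own instead of the tutorial's tensors:

```python
# Toy stand-in for get_batch's slicing: target is data shifted by one.
def get_batch_toy(source, i, bptt=3):
    seq_len = min(bptt, len(source) - 1 - i)
    data = source[i:i + seq_len]
    target = source[i + 1:i + 1 + seq_len]
    return data, target

tokens = ["the", "cat", "sat", "on", "the", "mat"]
data, target = get_batch_toy(tokens, 0)
print(data)    # ['the', 'cat', 'sat']
print(target)  # ['cat', 'sat', 'on']
```

So for every position, the target is simply the next token of the same sequence.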
Then I pass the input and target to my model like this:
data, targets = get_batch(train_data, i)
optimizer.zero_grad()
output = model(data, targets, batch=batch, epoch=epoch)
loss = criterion(output.view(-1, ntokens), targets.view(-1))
The model then processes the src and tgt like so:
def forward(self, x, tgt, epoch=0, batch=0, flag="train"):
    x = self.emb(x)
    x = self.posenc(x)
    tgt = self.posenc(self.emb(tgt))
    x = self.transformer(x)  # , generate_square_subsequent_mask(x.shape[0]))
    self.trans_activation = x
    if self.num_decoder_layers:
        x = self.decoder(tgt, x,
                         tgt_mask=generate_square_subsequent_mask(tgt.shape[0]),
                         memory_mask=generate_square_subsequent_mask(x.shape[0]))
    if not self.training and flag == "lollo":
        self.save_activations(epoch=epoch, batch=batch)
    return self.linear(x)
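For reference, this is the causal pattern I understand generate_square_subsequent_mask is supposed to produce (0.0 for allowed positions, -inf for masked ones, which matches PyTorch's additive-mask convention as far as I know), sketched in pure Python:

```python
import math

# Causal (subsequent) mask sketch: row i may attend only to columns <= i.
# 0.0 means "allowed"; -inf is added to the scores before the softmax,
# zeroing out attention to future positions.
def square_subsequent_mask(n):
    return [[0.0 if j <= i else -math.inf for j in range(n)]
            for i in range(n)]

for row in square_subsequent_mask(4):
    print(row)
```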
I have seen people shift the tgt sequence before passing it to the model and again during the loss calculation, but it is not clear to me why this is needed, since get_batch already returns a target shifted by one and a causal mask is applied. Also, since the seq len of my src and tgt are the same, the memory mask in my case is square, right?
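If I understand correctly, the shift people describe is the teacher-forcing convention below, where the decoder input and the loss labels are the same sequence offset by one (toy example; the <bos>/<eos> markers are my own):

```python
# Teacher-forcing shift: the decoder is fed the sequence up to position t
# and trained to predict the token at position t+1.
sentence = ["<bos>", "a", "b", "c", "<eos>"]

decoder_input = sentence[:-1]  # ['<bos>', 'a', 'b', 'c']
labels = sentence[1:]          # ['a', 'b', 'c', '<eos>']

# At step t the causal mask lets the decoder see decoder_input[:t+1],
# and the loss compares its prediction against labels[t].
for t in range(len(decoder_input)):
    print(decoder_input[:t + 1], "->", labels[t])
```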
This is my inference function:
def generate_sent(frase, model, TEXT, words=5):
    """
    Takes a list of tokens (a split string) and returns a new
    string with the next `words` tokens predicted by the model.
    """
    model.eval()
    with torch.no_grad():
        for i in range(words):
            frasevec = TEXT.numericalize([frase]).to("cuda")
            var = model(frasevec, frasevec).squeeze(1)
            most = torch.argmax(var[-1])
            frase.append(TEXT.vocab.itos[most.item()])
    return " ".join(frase)