Tgt and memory masks for transformer decoder layers

Hi everyone,

I’ve read previous posts about how to implement these masks, but things are still not clear to me for my use case.

I am training a language model with a transformer encoder-decoder. I know I could use just an encoder for an LM, but I want to try this anyway and understand it.

My problem is that after 4 or 5 epochs the perplexity drops very low, but at inference the model only repeats the last token of the sentence it is given.

My model takes its input from the get_batch function of the seq2seq tutorial:

def get_batch(source, i):
    seq_len = min(bptt, len(source) - 1 - i)
    data = source[i:i+seq_len]
    target = source[i+1:i+1+seq_len]
    return data, target
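
To make the offset between data and target concrete, here is a small sketch that runs the same slicing on a plain Python list instead of a tensor. The toy `source` and the value `bptt = 4` are assumptions chosen for illustration only.

```python
bptt = 4  # assumed chunk length, for illustration only

def get_batch(source, i):
    # Same slicing as above; it works on a list as well as on a tensor.
    seq_len = min(bptt, len(source) - 1 - i)
    data = source[i:i+seq_len]
    target = source[i+1:i+1+seq_len]
    return data, target

source = list(range(10, 20))  # stand-in for a stream of token ids
data, target = get_batch(source, 0)
print(data)    # [10, 11, 12, 13]
print(target)  # [11, 12, 13, 14]
```

So target[t] is the token that follows data[t] in the stream.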

Then I pass the input and the target to my model like this:

        data, targets = get_batch(train_data, i)
        output = model(data, targets, batch=batch, epoch=epoch)
        loss = criterion(output.view(-1, ntokens), targets.view(-1))

Then the model processes the src and tgt like so:

    def forward(self, x, tgt, epoch=0, batch=0, flag="train"):
        x = self.emb(x)
        x = self.posenc(x)
        tgt = self.posenc(self.emb(tgt))
        x = self.transformer(x)#, generate_square_subsequent_mask(x.shape[0]))
        self.trans_activation = x
        if self.num_decoder_layers:
            x = self.decoder(tgt, x, tgt_mask=generate_square_subsequent_mask(tgt.shape[0]), memory_mask=generate_square_subsequent_mask(x.shape[0]))
        if flag == "lollo":
            self.save_activations(epoch=epoch, batch=batch)
        return self.linear(x)
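
For reference, this is what a square subsequent (causal) mask looks like. The sketch is a plain-Python stand-in written for illustration, not the library's `generate_square_subsequent_mask`; it assumes the usual additive convention where 0.0 leaves an attention score unchanged and -inf blocks it. The name `causal_mask` is made up.

```python
def causal_mask(size):
    # Row = query position, column = key position.
    # 0.0 means "may attend", -inf means "blocked" (added to attention scores).
    neg_inf = float("-inf")
    return [[0.0 if col <= row else neg_inf for col in range(size)]
            for row in range(size)]

for row in causal_mask(4):
    print(row)
```

Row t has zeros in columns 0..t, so position t can attend to itself and to everything before it. Passed as a memory mask, the same matrix would restrict decoder position t to encoder outputs 0..t.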

I have seen people shift the tgt sequence before passing it to the model and again for the loss calculation. However, it is not clear to me why this should be done, since the tgt is already shifted and a mask is applied. Moreover, the memory mask in my case is square, since the sequence lengths of my src and tgt are the same, right?
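
One way to reason about whether the extra shift matters is to list, for each decoder position, which tokens a causal mask makes visible and compare them with the label at that position. The sketch below does this on plain lists. The token values and the helper name `visible_tokens` are made up for illustration, and it does not involve the actual model.

```python
def visible_tokens(decoder_input, position):
    # Under a causal mask, position t may attend to positions 0..t (inclusive).
    return decoder_input[:position + 1]

source = [10, 11, 12, 13, 14]           # toy stream of token ids (assumed)
data, target = source[:-1], source[1:]  # same offset as get_batch

# Case 1: the decoder input is `target`, and the labels are also `target`.
leaks_same = [target[t] in visible_tokens(target, t)
              for t in range(len(target))]
print(leaks_same)     # [True, True, True, True]

# Case 2: the decoder input is shifted right by one relative to the labels.
decoder_input, labels = target[:-1], target[1:]
leaks_shifted = [labels[t] in visible_tokens(decoder_input, t)
                 for t in range(len(labels))]
print(leaks_shifted)  # [False, False, False]
```

In the first case the mask hides only later positions, so the label at position t is still visible, because it sits at position t of the decoder input itself. In the second case each label lies one step beyond what its position can see.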

This is my inference function:

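import torch

def generate_sent(frase, model, TEXT, words=5):
    """
    This function gets a list of a split string and
    returns a new string with the next N number of words
    predicted by the model
    """
    for i in range(words):
        frasevec = TEXT.numericalize([frase]).to("cuda")
        var = model(frasevec, frasevec).squeeze(1)
        # Most likely next token, read from the last position of the output
        most = torch.argmax(var[-1])
        # Assumes TEXT is a torchtext Field with a built vocab (itos: index -> word)
        frase.append(TEXT.vocab.itos[most.item()])
    return " ".join(frase)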
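
A greedy generation loop usually feeds each predicted token back into the input before the next step. Here is a minimal, framework-free sketch of that pattern. `greedy_generate` and `toy_scores` are hypothetical names, and `toy_scores` is only a stand-in for a real model, so this shows the loop structure rather than the behaviour of the model above.

```python
def greedy_generate(tokens, next_token_scores, steps=5):
    # tokens: list of token ids; next_token_scores: callable returning one
    # score per vocabulary entry for the token that follows `tokens`.
    tokens = list(tokens)  # copy so the caller's list is not modified
    for _ in range(steps):
        scores = next_token_scores(tokens)
        best = max(range(len(scores)), key=scores.__getitem__)  # argmax
        tokens.append(best)  # feed the prediction back in on the next step
    return tokens

def toy_scores(tokens, vocab_size=5):
    # Hypothetical stand-in for a model: favours (last token + 1) % vocab_size.
    favoured = (tokens[-1] + 1) % vocab_size
    return [1.0 if i == favoured else 0.0 for i in range(vocab_size)]

result = greedy_generate([0, 1], toy_scores, steps=4)
print(result)  # [0, 1, 2, 3, 4, 0]
```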