Help Needed: Transformer Model Repeating Last Token During Inference

Hi everyone,

I’ve been trying to learn about transformers and wanted to start by trying to implement PyTorch’s nn.TransformerEncoder and nn.TransformerDecoder solutions into a simple model.
But I’m running into a consistent issue that I’m unable to resolve where during inference the model only produces the last token fed into it.
For example lets say I have a tensor [1,2,3,4,5] the model will continue the sequence with [1,2,3,4,5,5,5,5,5,5,…] or if I had [5,2,8,3] it would continue to produce [5,2,8,3,3,3,3,3,3,3,…].

Although it produces the above results the loss continues to decrease as I train it as if its managing to learn the dataset. So initially I thought this was just a problem with the dataset where the target was the same as the input which would cause it to produce the same tokens, but after further testing I’m 100% sure that the targets are defiantly the next token in the sequence for example the input would be [1,2,3,4] and the target would be [2,3,4,5].

After this I was left confused and didn’t know what to try next so I went to research and try to implement different implementations of the common components such as positional encoding and adjusting hyper-parameters. but regardless still weeks later and I’m still zero progress towards identifying the issue.

So now I’m at the point where its getting frustration and I don’t think I can solve the problem on my own given my limited knowledge which is why I’m asking for help here.

For reference here is the model and training step I’m using:

class TextEmbedding(nn.Module):
    def __init__(self, vocab_size: int, embed_dim: int, padding_index: int):
        super(TextEmbedding, self).__init__()
        self.embedding = nn.Embedding(num_embeddings=vocab_size, embedding_dim=embed_dim, padding_idx=padding_index)

    def forward(self, x):
        return self.embedding(x)

class TextTransformer(nn.Module):
    def __init__(self, vocab_size, embed_dim = 512, nhead = 8, num_encoder_layers = 6, num_decoder_layers = 6, max_length = 5000, padding_index = 0):
        super(TextTransformer, self).__init__()
        self.vocab_size = vocab_size
        self.max_length = max_length

        self.text_embedding = TextEmbedding(vocab_size, embed_dim, padding_index)
        self.positional_encoding = nn.Parameter(torch.zeros(1, max_length, embed_dim))

        encoder_layer = nn.TransformerEncoderLayer(d_model=embed_dim, nhead=nhead, dim_feedforward=2048)
        self.encoder = nn.TransformerEncoder(encoder_layer=encoder_layer, num_layers=num_encoder_layers)

        decoder_layer = nn.TransformerDecoderLayer(d_model=embed_dim, nhead=nhead, dim_feedforward=2048)
        self.decoder = nn.TransformerDecoder(decoder_layer=decoder_layer, num_layers=num_decoder_layers)

        self.fc = nn.Sequential(
            nn.Linear(embed_dim, vocab_size)
        )

    def forward(self, src, tgt, src_mask, tgt_mask):
        #Embedding + Positional Encoding
        src_embedding = self.text_embedding(src) + self.positional_encoding[:, :src.size(1), :]
        tgt_embedding = self.text_embedding(tgt) + self.positional_encoding[:, :tgt.size(1), :]

        tgt_square_mask = create_square_mask(tgt.size(1)).to(src.device)

        #Encoder
        memory = self.encoder(src_embedding.permute(1, 0, 2), src_key_padding_mask=src_mask)

        #Decoder
        decoder_out = self.decoder(tgt_embedding.permute(1, 0, 2), memory, tgt_mask=tgt_square_mask, tgt_key_padding_mask=tgt_mask)
        decoder_out = decoder_out.permute(1, 0, 2)

        #FC output
        output = self.fc(decoder_out)

        return output

    def seq2seq(self, src, src_mask, stop_token, max_length = 500):
        src_embedding = self.text_embedding(src) + self.positional_encoding[:, :src.size(1), :]

        memory = self.encoder(src_embedding.permute(1, 0, 2), src_key_padding_mask=src_mask)
        sequence = src
        stop = False

        while sequence.shape[1] < min(self.max_length, max_length) and not stop:
            tgt_embedding = self.text_embedding(sequence) + self.positional_encoding[:, :sequence.size(1), :]

            tgt_square_mask = create_square_mask(sequence.size(1)).to(src.device)
            dec_output = self.decoder(tgt_embedding.permute(1, 0, 2), memory, tgt_mask=tgt_square_mask)
            dec_output = dec_output.permute(1, 0, 2)

            out = self.fc(dec_output)[:, -1, :]
            predicted = out.argmax(dim=1)
            
            if predicted.item() == stop_token:
                stop = True

            sequence = torch.cat((sequence, predicted.unsqueeze(dim=0)),dim=1)

        return sequence

    def create_square_mask(size):
    	mask = torch.triu(torch.ones(size, size), diagonal=1)
    	mask = mask.masked_fill(mask == 1, float('-inf')).masked_fill(mask == 0, float(0.0))
    	return mask

def train_step(model, dataloader, criterion, optimizer, device):
    avg_loss = 0
    model.train()
    for batch, (text_data, text_pad_mask) in enumerate(dataloader):
        text_data, text_pad_mask = text_data.to(device), text_pad_mask.to(device)

        #shift data so that the in_text is the initial tokens and that tgt_text is the next predicted token in the sequence
        in_text = text_data[:, :-1]
        in_mask = text_pad_mask[:, :-1]
        tgt_text = text_data[:, 1:]
        tgt_mask = text_pad_mask[:, 1:]


        out = model(in_text, tgt_text, in_mask, tgt_mask)

        outputs = out[:, :].reshape(-1, model.vocab_size)# Reshape to [batch_size * steps, vocab_size]
        targets = tgt_text[:, :].reshape(-1)# Reshape to [batch_size * steps]

        loss = criterion(outputs, targets)
        avg_loss += loss.item()

        loss.backward()
        optimizer.step()
        optimizer.zero_grad()
    return avg_loss / len(dataloader)

I think this is all that is necessary to try diagnose the issue as I’m 100% sure the tokenizer and data loader is working perfectly as I’ve done a lot of testing on them and don’t want to flood this post with too much code but I can provide the code for them upon request if it helps at all.

If anyone could help me with this problem it would be massively appreciated as this has been something which has stumped me for weeks now.

Thanks for your time :slight_smile: