Help Needed: Transformer Model Repeating Last Token During Inference

Hi everyone,

I've been trying to learn about transformers and wanted to start by building a simple model around PyTorch's nn.TransformerEncoder and nn.TransformerDecoder modules.
But I'm running into a consistent issue that I'm unable to resolve: during inference the model only ever produces the last token that was fed into it.
For example, given the tensor [1,2,3,4,5] the model continues the sequence as [1,2,3,4,5,5,5,5,5,5,…], and given [5,2,8,3] it produces [5,2,8,3,3,3,3,3,3,3,…].

Despite this behaviour, the loss continues to decrease as I train, as if the model is managing to learn the dataset. So initially I thought this was a problem with the dataset, with the target being the same as the input, which would cause it to reproduce the same tokens. After further testing, though, I'm 100% sure the targets really are the next tokens in the sequence: for example, the input would be [1,2,3,4] and the target would be [2,3,4,5].

After this I was left confused and didn't know what to try next, so I researched and tried different implementations of the common components, such as the positional encoding, and adjusted hyper-parameters. But weeks later I've still made zero progress towards identifying the issue.

So now I'm at the point where it's getting frustrating, and I don't think I can solve the problem on my own given my limited knowledge, which is why I'm asking for help here.

For reference here is the model and training step I’m using:

import torch
import torch.nn as nn

class TextEmbedding(nn.Module):
    def __init__(self, vocab_size: int, embed_dim: int, padding_index: int):
        super(TextEmbedding, self).__init__()
        self.embedding = nn.Embedding(num_embeddings=vocab_size, embedding_dim=embed_dim, padding_idx=padding_index)

    def forward(self, x):
        return self.embedding(x)

class TextTransformer(nn.Module):
    def __init__(self, vocab_size, embed_dim = 512, nhead = 8, num_encoder_layers = 6, num_decoder_layers = 6, max_length = 5000, padding_index = 0):
        super(TextTransformer, self).__init__()
        self.vocab_size = vocab_size
        self.max_length = max_length

        self.text_embedding = TextEmbedding(vocab_size, embed_dim, padding_index)
        self.positional_encoding = nn.Parameter(torch.zeros(1, max_length, embed_dim))

        encoder_layer = nn.TransformerEncoderLayer(d_model=embed_dim, nhead=nhead, dim_feedforward=2048)
        self.encoder = nn.TransformerEncoder(encoder_layer=encoder_layer, num_layers=num_encoder_layers)

        decoder_layer = nn.TransformerDecoderLayer(d_model=embed_dim, nhead=nhead, dim_feedforward=2048)
        self.decoder = nn.TransformerDecoder(decoder_layer=decoder_layer, num_layers=num_decoder_layers)

        self.fc = nn.Sequential(
            nn.Linear(embed_dim, vocab_size)
        )

    def forward(self, src, tgt, src_mask, tgt_mask):
        #Embedding + Positional Encoding
        src_embedding = self.text_embedding(src) + self.positional_encoding[:, :src.size(1), :]
        tgt_embedding = self.text_embedding(tgt) + self.positional_encoding[:, :tgt.size(1), :]

        tgt_square_mask = create_square_mask(tgt.size(1)).to(src.device)

        #Encoder
        memory = self.encoder(src_embedding.permute(1, 0, 2), src_key_padding_mask=src_mask)

        #Decoder
        decoder_out = self.decoder(tgt_embedding.permute(1, 0, 2), memory, tgt_mask=tgt_square_mask, tgt_key_padding_mask=tgt_mask)
        decoder_out = decoder_out.permute(1, 0, 2)

        #FC output
        output = self.fc(decoder_out)

        return output

    def seq2seq(self, src, src_mask, stop_token, max_length = 500):
        src_embedding = self.text_embedding(src) + self.positional_encoding[:, :src.size(1), :]

        memory = self.encoder(src_embedding.permute(1, 0, 2), src_key_padding_mask=src_mask)
        sequence = src
        stop = False

        while sequence.shape[1] < min(self.max_length, max_length) and not stop:
            tgt_embedding = self.text_embedding(sequence) + self.positional_encoding[:, :sequence.size(1), :]

            tgt_square_mask = create_square_mask(sequence.size(1)).to(src.device)
            dec_output = self.decoder(tgt_embedding.permute(1, 0, 2), memory, tgt_mask=tgt_square_mask)
            dec_output = dec_output.permute(1, 0, 2)

            out = self.fc(dec_output)[:, -1, :]
            predicted = out.argmax(dim=1)
            
            if predicted.item() == stop_token:
                stop = True

            sequence = torch.cat((sequence, predicted.unsqueeze(dim=0)),dim=1)

        return sequence

def create_square_mask(size):
    mask = torch.triu(torch.ones(size, size), diagonal=1)
    mask = mask.masked_fill(mask == 1, float('-inf')).masked_fill(mask == 0, float(0.0))
    return mask

def train_step(model, dataloader, criterion, optimizer, device):
    avg_loss = 0
    model.train()
    for batch, (text_data, text_pad_mask) in enumerate(dataloader):
        text_data, text_pad_mask = text_data.to(device), text_pad_mask.to(device)

        #shift data so that the in_text is the initial tokens and that tgt_text is the next predicted token in the sequence
        in_text = text_data[:, :-1]
        in_mask = text_pad_mask[:, :-1]
        tgt_text = text_data[:, 1:]
        tgt_mask = text_pad_mask[:, 1:]


        out = model(in_text, tgt_text, in_mask, tgt_mask)

        outputs = out.reshape(-1, model.vocab_size)  # reshape to [batch_size * steps, vocab_size]
        targets = tgt_text.reshape(-1)  # reshape to [batch_size * steps]

        loss = criterion(outputs, targets)
        avg_loss += loss.item()

        loss.backward()
        optimizer.step()
        optimizer.zero_grad()
    return avg_loss / len(dataloader)

I think this is everything needed to diagnose the issue. I'm 100% sure the tokenizer and data loader are working correctly as I've done a lot of testing on them, and I don't want to flood this post with too much code, but I can provide the code for them on request if it helps at all.

If anyone could help me with this problem it would be massively appreciated as this has been something which has stumped me for weeks now.

Thanks for your time 🙂

I was hitting a similar problem. The model learned to predict the last token in the input sequence instead of learning to predict the next token. You need to make sure your model can’t accidentally cheat by looking ahead at the target sequence.

The three things to check:

  1. Are you shifting the target tokens before feeding them into the model? (There is a minimal sketch of this after the example below.)
  2. Are you calculating the loss against the non-shifted target tokens?
  3. Is the masking correct?

So for example if the input is:
“I am a cat”

And the target is the French translation “je suis un chat”

And we are using “<|start|>” and “<|end|>” as the start/end tokens

Then the model should get an input of:

model("I am a a cat<|end|>", "<|start|>je suis un chat<|end|>")

target = "je suis un chat<|end|>"

If you look at it position by position, with the full source sentence visible to the encoder, then:

  1. Given the decoder input “<|start|>”, the model should predict “je”
  2. Given the decoder input “<|start|> je”, the model should predict “suis”
  3. Given the decoder input “<|start|> je suis”, the model should predict “un”
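
To make the shifting concrete, here is a minimal runnable sketch of what I mean (the token ids are made up just for the example):

import torch

# Placeholder ids for: <|start|> je suis un chat <|end|>  (values are made up)
tgt_tokens = torch.tensor([[1, 11, 12, 13, 14, 2]])

decoder_input = tgt_tokens[:, :-1]   # <|start|> je suis un chat  -> fed to the decoder
loss_target   = tgt_tokens[:, 1:]    # je suis un chat <|end|>    -> used in the loss

print(decoder_input)  # tensor([[ 1, 11, 12, 13, 14]])
print(loss_target)    # tensor([[11, 12, 13, 14,  2]])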

This thread is a bit old, but I encountered the same issue and found it, so I want to share the two steps that solved the problem for me:

  1. Make sure the token shifting is correct for autoregressive generation and verify there is a start-of-sequence token
  2. Too much normalization can shrink the gradients to 0, which keeps the model from learning; instead it may just output the last token (a quick gradient-norm check is sketched below)
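
If you suspect the second point, a quick diagnostic (just a sketch, not from the original code) is to print the gradient norms right after loss.backward() in the training loop and see whether they are collapsing towards zero:

# Diagnostic sketch: run right after loss.backward() inside the training loop
total_sq = 0.0
for name, param in model.named_parameters():
    if param.grad is not None:
        g = param.grad.norm().item()
        total_sq += g ** 2
        print(f"{name}: grad norm = {g:.6f}")
print(f"total grad norm = {total_sq ** 0.5:.6f}")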

Thanks for the replies everyone, I appreciate it, and apologies for the late response, but I have solved the problem now.

The problem was that I was using an encoder-decoder model when the task I was trying to do only needed a decoder. (I think nn.TransformerDecoder would be the preferred module, but when I made the quick change and started training to test it out I accidentally used nn.TransformerEncoder in the solution below; it seemed to work fine anyway.) Here is the updated solution:

import math
import torch
import torch.nn as nn

class PositionalEncoding(nn.Module):
    def __init__(self, d_model: int, max_len: int = 5000):
        super(PositionalEncoding, self).__init__()

        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        
        pe = torch.zeros(1, max_len, d_model)
        pe[0, :, 0::2] = torch.sin(position * div_term)
        pe[0, :, 1::2] = torch.cos(position * div_term)
        
        self.register_buffer('pe', pe)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = x + self.pe[:, :x.size(1), :]
        return x


class TextEmbedding(nn.Module):
    def __init__(self, vocab_size: int, embed_dim: int, padding_index: int):
        super(TextEmbedding, self).__init__()
        self.embedding = nn.Embedding(num_embeddings=vocab_size, embedding_dim=embed_dim, padding_idx=padding_index)

    def forward(self, x):
        return self.embedding(x)

class TextTransformer(nn.Module):
    def __init__(self, vocab_size: int, embed_dim: int = 512, nhead: int = 8, num_layers: int = 6, max_length: int = 5000, padding_index: int = 0):
        super(TextTransformer, self).__init__()
        self.vocab_size = vocab_size
        self.max_length = max_length

        self.text_embedding = TextEmbedding(vocab_size, embed_dim, padding_index)
        self.positional_encoding = PositionalEncoding(embed_dim, max_len=max_length)

        encoder_layer = nn.TransformerEncoderLayer(d_model=embed_dim, nhead=nhead, dim_feedforward=2048)
        self.decoder = nn.TransformerEncoder(encoder_layer=encoder_layer, num_layers=num_layers)

        self.fc = nn.Sequential(
            nn.Linear(embed_dim, vocab_size)
        )

    def forward(self, seq, seq_mask) -> torch.Tensor:
        #Embedding + Positional Encoding
        embedding = self.positional_encoding(self.text_embedding(seq))
        square_mask = create_square_mask(seq.size(1)).to(seq.device)

        #Decoder
        decoder_out = self.decoder(embedding.permute(1, 0, 2), mask=square_mask, src_key_padding_mask=seq_mask, is_causal=True)
        decoder_out = decoder_out.permute(1, 0, 2)

        #FC output
        output = self.fc(decoder_out)

        return output

    def seq2seq(self, seq: torch.Tensor, seq_mask: torch.Tensor, stop_token: int, max_length: int = 500, top_p: float = 0.5, temperature: float = 1.0):
        sequence = seq
        stop = False

        while sequence.shape[1] < min(self.max_length, max_length) and not stop:
            embedding = self.positional_encoding(self.text_embedding(sequence))

            square_mask = create_square_mask(sequence.size(1)).to(seq.device)
            dec_output = self.decoder(embedding.permute(1, 0, 2), mask=square_mask)
            dec_output = dec_output.permute(1, 0, 2)

            out = self.fc(dec_output)[:, -1, :]
            predicted = token_selection(out, top_p, temperature)
            #predicted = out.argmax(dim=1).view(-1, 1)
            
            if predicted.item() == stop_token:
                stop = True

            sequence = torch.cat((sequence, predicted),dim=1)

        return sequence

def create_square_mask(size: int) -> torch.Tensor:
    mask = torch.triu(torch.ones(size, size), diagonal=1)
    mask = mask.masked_fill(mask == 1, float('-inf')).masked_fill(mask == 0, float(0.0))
    return mask
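
As an aside, create_square_mask should produce the same values as PyTorch's built-in nn.Transformer.generate_square_subsequent_mask helper, so it can be sanity-checked against it (I believe they match, but it's worth confirming on your PyTorch version):

# Optional sanity check: compare against PyTorch's built-in causal mask helper
builtin_mask = nn.Transformer.generate_square_subsequent_mask(5)
print(builtin_mask)
print(create_square_mask(5))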

def token_selection(logits: torch.Tensor, top_p: float = 0.5, temperature: float = 1.0):  # logits shape: [1, vocab_size]
    logits = logits / temperature

    probabilities = torch.softmax(logits, dim=-1)
    sorted_probabilities, sorted_indices = torch.sort(probabilities, descending=True)
    cumulative_probabilities = torch.cumsum(sorted_probabilities, dim=-1)

    cutoff_mask = cumulative_probabilities <= top_p
    cutoff_index = max(cutoff_mask.sum().item(), 1)

    top_p_probabilities = sorted_probabilities[:, :cutoff_index]
    top_p_indices = sorted_indices[:, :cutoff_index]

    top_p_probabilities /= top_p_probabilities.sum()

    next_tokens = torch.multinomial(top_p_probabilities, 1)
    selected_indices = torch.gather(top_p_indices, 1, next_tokens)
    return selected_indices
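
And for reference, a quick standalone check of token_selection with made-up logits (a vocab size of 100 is picked just for the example):

# Standalone check of the nucleus sampler with dummy logits
dummy_logits = torch.randn(1, 100)  # shape [batch=1, vocab_size=100]
next_token = token_selection(dummy_logits, top_p=0.9, temperature=0.8)
print(next_token.shape)  # torch.Size([1, 1])
print(next_token.item())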