Transformer Chatbot Model: Loss Not Improving Despite Gradient Changes

I’m working on a Transformer-based chatbot model inspired by the “Attention Is All You Need” paper, built from PyTorch’s built-in Transformer modules. However, I’m facing an issue where the loss stays essentially flat from one epoch to the next, even though the gradients are clearly changing during training.
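
To be concrete about “the gradients are changing”: after each backward pass I log the total gradient norm with a small helper along these lines (a simplified sketch, not the exact code from my notebook):

import torch

def total_grad_norm(model):
    # Total L2 norm over all parameters that received a gradient in the last backward pass
    grads = [p.grad.detach().norm() for p in model.parameters() if p.grad is not None]
    return torch.norm(torch.stack(grads)).item() if grads else 0.0

# Called right after loss.backward(), e.g.:
#   print(f"epoch {epoch}  loss {loss.item():.4f}  grad norm {total_grad_norm(model):.4f}")

That norm varies from step to step and from epoch to epoch, but the loss itself barely moves.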

Here’s the code for my model:

import math
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F

from torch.nn.utils.rnn import pad_sequence
from transformers import AutoTokenizer

torch.set_printoptions(sci_mode=False)

# Run on GPU when available; `device` is used below for the masks and for generation
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

tokenizer = AutoTokenizer.from_pretrained('t5-small')
tokenizer.add_special_tokens({'bos_token': '<s>'})

class TransformerChatbot(nn.Module):
    def __init__(self, vocab_size, embed_size, hidden_size, num_layers, num_heads, max_length, dropout):
        super(TransformerChatbot, self).__init__()
 
        self.embed_size = embed_size
        self.max_length = max_length
        self.embedding = nn.Embedding(vocab_size, embed_size)
        self.positional_encoding = PositionalEncoding(embed_size, max_length, dropout)
        
        self.transformer_encoder = nn.TransformerEncoder(
            nn.TransformerEncoderLayer(d_model=embed_size, nhead=num_heads, dim_feedforward=hidden_size,
                                       dropout=dropout, batch_first=True),
            num_layers
        )
        
        self.transformer_decoder = nn.TransformerDecoder(
            nn.TransformerDecoderLayer(d_model=embed_size, nhead=num_heads, dim_feedforward=hidden_size,
                                       dropout=dropout, batch_first=True),
            num_layers
        )
       
        self.linear = nn.Linear(embed_size, vocab_size)
        self._reset_parameters()
    
    def _create_masks(self, input_seq, target_seq, pad_token_id):
        # Padding masks: 1.0 where the token is padding, 0.0 elsewhere (shape [batch, seq_len])
        input_mask = (input_seq == pad_token_id).float()
        target_mask = (target_seq == pad_token_id).float()
        # Causal mask for the decoder: 0 on/below the diagonal, -inf above (shape [tgt_len, tgt_len])
        tgt_mask = nn.Transformer.generate_square_subsequent_mask(target_seq.size(1)).to(device)
        
        return input_mask, target_mask, tgt_mask
   
    def forward(self, input_seq, target_seq, pad_token_id=None):
        if pad_token_id is None:
            pad_token_id = tokenizer.pad_token_id
        
        input_mask, target_mask, tgt_mask = self._create_masks(input_seq, target_seq, pad_token_id)
        
        # Token embeddings + sinusoidal positional encodings for source and target
        input_embed = self.embedding(input_seq)
        input_embed = self.positional_encoding(input_embed)
        target_embed = self.embedding(target_seq)
        target_embed = self.positional_encoding(target_embed)
        
        # Encode the source, then decode the target against the encoder memory
        encoder_output = self.transformer_encoder(input_embed, src_key_padding_mask=input_mask)
        decoder_output = self.transformer_decoder(
            target_embed, 
            encoder_output,
            tgt_mask=tgt_mask, 
            tgt_key_padding_mask=target_mask, 
            memory_key_padding_mask=input_mask,
        )
        
        # Project decoder states to vocabulary logits
        output_logits = self.linear(decoder_output)
        
        return output_logits
    
    def generate(self, input_ids, max_length=None, temperature=1.0):
        if max_length is None:
            max_length = self.max_length
        
        self.eval()
        # Start decoding from the <s> token added to the tokenizer above
        generated_seq = torch.full((1, 1), tokenizer.bos_token_id, dtype=torch.long, device=device)
        
        with torch.no_grad():
            # Encode the prompt once and reuse the memory at every decoding step
            input_embed = self.embedding(input_ids)
            input_embed = self.positional_encoding(input_embed)
            encoder_output = self.transformer_encoder(input_embed)
            
            for _ in range(max_length-1):
                tgt_mask = nn.Transformer.generate_square_subsequent_mask(generated_seq.size(1)).to(device)
                
                input_embed = self.embedding(generated_seq)
                output_embed = self.positional_encoding(input_embed)
                decoder_output = self.transformer_decoder(
                    output_embed,
                    encoder_output,
                    tgt_mask=tgt_mask, 
                )
                
                output_logits = self.linear(decoder_output)
                
                # Temperature-scaled sampling from the distribution over the last position
                next_token_logits = output_logits[:, -1, :] / temperature
                next_token_probs = F.softmax(next_token_logits, dim=-1)
                next_token = torch.multinomial(next_token_probs, 1)
                generated_seq = torch.cat([generated_seq, next_token], dim=1)
                
                if next_token.item() == tokenizer.eos_token_id:
                    break
                
        generated_tensor = generated_seq
        generated_seq = generated_seq.squeeze().tolist()
        generated_text = tokenizer.decode(generated_seq, skip_special_tokens=True)
        
        return len(generated_seq), generated_tensor, generated_text
    
    def _reset_parameters(self):
        # Xavier-initialize every weight matrix (dim > 1), including the embedding table
        for p in self.parameters():
            if p.dim() > 1:
                nn.init.xavier_uniform_(p)
    
# Positional Encoding module
class PositionalEncoding(nn.Module):
    def __init__(self, embed_size, max_len=512, dropout=0.1):
        super(PositionalEncoding, self).__init__()
        
        self.dropout = nn.Dropout(p=dropout)
        
        # Standard sinusoidal encodings: sin on even embedding indices, cos on odd ones
        pe = torch.zeros(max_len, embed_size)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, embed_size, 2).float() * 
                             -(math.log(10000.0) / embed_size))
        
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        
        # Shape [1, max_len, embed_size]; registered as a buffer so it moves with .to(device)
        # and is saved in the state_dict, but is not trained
        pe = pe.unsqueeze(0)
        self.register_buffer('pe', pe)
        
    def forward(self, x):
        # x: [batch_size, sequence_length, embed_size]
        x = x + self.pe[:, :x.size(1), :]
        
        return self.dropout(x)
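
To make this easier to reproduce, here is a minimal smoke test of the model above. The hyperparameters and the hand-made padded batch are illustrative only; my actual data pipeline and settings are in the linked notebook.

# Minimal smoke test with illustrative hyperparameters (not my training setup)
model = TransformerChatbot(
    vocab_size=len(tokenizer),   # tokenizer vocabulary plus the added <s> token
    embed_size=128,
    hidden_size=256,
    num_layers=2,
    num_heads=4,
    max_length=64,
    dropout=0.1,
).to(device)

# Toy padded batch: 0 is t5-small's pad token id, 1 is its </s> (eos) id
input_seq = torch.tensor([[31, 7, 9, 0, 0],
                          [5, 2, 4, 8, 1]], device=device)
target_seq = torch.tensor([[31, 7, 1, 0, 0],
                           [5, 2, 4, 1, 0]], device=device)

logits = model(input_seq, target_seq)
print(logits.shape)   # [batch_size, target_len, vocab_size]

# generate() expects a single (1, seq_len) prompt; the untrained model produces gibberish,
# this just checks the plumbing
gen_len, gen_ids, gen_text = model.generate(input_seq[:1], max_length=10)
print(gen_len, gen_text)

On my setup this runs without errors and the shapes come out as expected, so the problem seems to be in the training itself rather than in the plumbing.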

Issues Encountered:

  • Loss Plateau: Despite observing changing gradients during training, the loss doesn’t improve significantly after each epoch.

  • Mask handling: I am unsure whether my mask creation is correct, particularly the padding masks and the subsequent-token (causal) mask; a quick check of what the masks look like is shown right below this list.
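
For example, reusing the model from the smoke test above, this is what _create_masks returns for a small padded batch (toy values; pad id 0 for t5-small):

toy_input = torch.tensor([[31, 7, 9, 0, 0]], device=device)
toy_target = torch.tensor([[31, 7, 1, 0, 0]], device=device)
input_mask, target_mask, tgt_mask = model._create_masks(toy_input, toy_target, pad_token_id=0)

print(input_mask)   # 0./1. floats marking pad positions, e.g. tensor([[0., 0., 0., 1., 1.]])
print(tgt_mask)     # 5x5 float mask: 0 on/below the diagonal, -inf above

In particular, I’m not sure whether the key_padding_mask arguments should be boolean masks or float masks like these.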

Questions:

  • Are there any mistakes or inefficiencies in how I’m handling the input and target masks?

  • Could there be an issue with my positional encoding or with how I initialize the parameters? (A quick isolated check of the positional encoding is shown after this list.)

  • Are there any common pitfalls with using nn.TransformerEncoder and nn.TransformerDecoder that I might be missing?
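
Here is a quick isolated check of the PositionalEncoding class above (sizes are illustrative):

pe_check = PositionalEncoding(embed_size=16, max_len=32, dropout=0.0)   # dropout off for the check
zeros = torch.zeros(1, 4, 16)            # zero input, so the output is the raw encoding
encoded = pe_check(zeros)
print(encoded.shape)                      # torch.Size([1, 4, 16])
print(encoded[0, 0, :4])                  # position 0: [0., 1., 0., 1.] (sin(0), cos(0), ...)
print(encoded[0, 1, :4])                  # position 1 differs, so positions are distinguishable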

Here is the link to the full source code:
https://www.kaggle.com/code/cutedeadu/transformer-chatbot-model-69e9a1

Any insights, suggestions, or pointers to what might be causing the stagnant loss would be greatly appreciated. Thank you in advance!