I’m working on a Transformer-based chatbot model inspired by the “Attention Is All You Need” paper, built with PyTorch’s built-in modules. However, I’m running into an issue where the model’s loss stays almost the same from epoch to epoch, even though the gradients are clearly changing during training.
Here’s the code for my model:
import math
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.nn.utils.rnn import pad_sequence
from transformers import AutoTokenizer
torch.set_printoptions(sci_mode=False)
# device is used throughout the model code below
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
tokenizer = AutoTokenizer.from_pretrained('t5-small')
tokenizer.add_special_tokens({'bos_token': '<s>'})
class TransformerChatbot(nn.Module):
    def __init__(self, vocab_size, embed_size, hidden_size, num_layers, num_heads, max_length, dropout):
        super(TransformerChatbot, self).__init__()
        self.embed_size = embed_size
        self.max_length = max_length
        self.embedding = nn.Embedding(vocab_size, embed_size)
        self.positional_encoding = PositionalEncoding(embed_size, max_length, dropout)
        self.transformer_encoder = nn.TransformerEncoder(
            nn.TransformerEncoderLayer(embed_size, num_heads, hidden_size, dropout, batch_first=True),
            num_layers
        )
        self.transformer_decoder = nn.TransformerDecoder(
            nn.TransformerDecoderLayer(embed_size, num_heads, hidden_size, dropout, batch_first=True),
            num_layers
        )
        self.linear = nn.Linear(embed_size, vocab_size)
        self._reset_parameters()

    def _create_masks(self, input_seq, target_seq, pad_token_id):
        # padding masks: 1.0 at padded positions, 0.0 elsewhere
        input_mask = (input_seq == pad_token_id).float()
        target_mask = (target_seq == pad_token_id).float()
        # causal mask so the decoder cannot attend to future positions
        tgt_mask = nn.Transformer.generate_square_subsequent_mask(target_seq.size(1)).to(device)
        return input_mask, target_mask, tgt_mask
    def forward(self, input_seq, target_seq, pad_token_id=None):
        if pad_token_id is None:
            pad_token_id = tokenizer.pad_token_id
        input_mask, target_mask, tgt_mask = self._create_masks(input_seq, target_seq, pad_token_id)
        input_embed = self.embedding(input_seq)
        input_embed = self.positional_encoding(input_embed)
        target_embed = self.embedding(target_seq)
        target_embed = self.positional_encoding(target_embed)
        encoder_output = self.transformer_encoder(input_embed, src_key_padding_mask=input_mask)
        decoder_output = self.transformer_decoder(
            target_embed,
            encoder_output,
            tgt_mask=tgt_mask,
            tgt_key_padding_mask=target_mask,
            memory_key_padding_mask=input_mask,
        )
        output_logits = self.linear(decoder_output)
        return output_logits
    def generate(self, input_ids, max_length=None, temperature=1.0):
        if max_length is None:
            max_length = self.max_length
        self.eval()
        # start the generated sequence with the BOS token
        generated_seq = torch.ones(1, 1).fill_(tokenizer.bos_token_id).to(torch.int).to(device)
        with torch.no_grad():
            input_embed = self.embedding(input_ids)
            input_embed = self.positional_encoding(input_embed)
            encoder_output = self.transformer_encoder(input_embed)
            for _ in range(max_length - 1):
                tgt_mask = nn.Transformer.generate_square_subsequent_mask(generated_seq.size(1)).to(device)
                input_embed = self.embedding(generated_seq)
                output_embed = self.positional_encoding(input_embed)
                decoder_output = self.transformer_decoder(
                    output_embed,
                    encoder_output,
                    tgt_mask=tgt_mask,
                )
                output_logits = self.linear(decoder_output)
                # sample the next token from the temperature-scaled distribution
                next_token_logits = output_logits[:, -1, :] / temperature
                next_token_probs = F.softmax(next_token_logits, dim=-1)
                next_token = torch.multinomial(next_token_probs, 1)
                generated_seq = torch.cat([generated_seq, next_token], dim=1)
                if next_token.item() == tokenizer.eos_token_id:
                    break
        generated_tensor = generated_seq
        generated_seq = generated_seq.squeeze().tolist()
        generated_text = tokenizer.decode(generated_seq, skip_special_tokens=True)
        return len(generated_seq), generated_tensor, generated_text

    def _reset_parameters(self):
        for p in self.parameters():
            if p.dim() > 1:
                nn.init.xavier_uniform_(p)
# Positional Encoding module
class PositionalEncoding(nn.Module):
    def __init__(self, embed_size, max_len=512, dropout=0.1):
        super(PositionalEncoding, self).__init__()
        self.dropout = nn.Dropout(p=dropout)
        pe = torch.zeros(max_len, embed_size)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, embed_size, 2).float() *
                             -(math.log(10000.0) / embed_size))
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        pe = pe.unsqueeze(0)
        self.register_buffer('pe', pe)

    def forward(self, x):
        # x: [batch_size, sequence_length, embed_size]
        x = x + self.pe[:, :x.size(1), :]
        return self.dropout(x)
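For context, here is a stripped-down sketch of my training step. The hyperparameters, the dummy batch, and the one-position target shift (the teacher-forcing setup I’m aiming for) are placeholders meant only to show how the loss is computed; the real data pipeline is in the linked notebook.

# Simplified training step (placeholder hyperparameters and dummy data,
# just to show how the loss is computed; not the notebook's real pipeline).
model = TransformerChatbot(
    vocab_size=len(tokenizer), embed_size=256, hidden_size=512,
    num_layers=2, num_heads=4, max_length=64, dropout=0.1,
).to(device)
criterion = nn.CrossEntropyLoss(ignore_index=tokenizer.pad_token_id)
optimizer = optim.Adam(model.parameters(), lr=1e-4)

# Dummy batch of token ids, shape [batch_size, seq_len] (placeholder data).
input_seq = torch.randint(0, len(tokenizer), (2, 12)).to(device)
target_seq = torch.randint(0, len(tokenizer), (2, 12)).to(device)

model.train()
optimizer.zero_grad()
# Teacher forcing: feed target[:, :-1] to the decoder, predict target[:, 1:].
logits = model(input_seq, target_seq[:, :-1])
loss = criterion(logits.reshape(-1, logits.size(-1)),
                 target_seq[:, 1:].reshape(-1))
loss.backward()
optimizer.step()
print(loss.item())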
Issues Encountered:

- Loss Plateau: Despite observing changing gradients during training, the loss doesn’t improve significantly from one epoch to the next.
- Masks Handling: I am unsure whether my mask creation is correct, particularly the padding masks and the subsequent-token (causal) mask; see the quick sanity check after this list.
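To double-check the masks, I ran a quick sanity check on a toy batch, using the model instance from the training sketch above (the token ids below are made up purely to inspect shapes, dtypes, and values):

# Toy inputs with trailing padding, just to inspect the three masks.
toy_input = torch.tensor([[5, 8, 3, tokenizer.pad_token_id]]).to(device)
toy_target = torch.tensor([[tokenizer.bos_token_id, 7, 9, tokenizer.pad_token_id]]).to(device)
input_mask, target_mask, tgt_mask = model._create_masks(toy_input, toy_target, tokenizer.pad_token_id)
print(input_mask)   # float, 1.0 at padded positions: tensor([[0., 0., 0., 1.]])
print(target_mask)  # tensor([[0., 0., 0., 1.]])
print(tgt_mask)     # [4, 4] float matrix with -inf above the diagonal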
Questions:

- Are there any mistakes or inefficiencies in how I’m handling the input and target masks?
- Could there be an issue with my positional encoding or the initialization of parameters?
- Are there any common pitfalls with using nn.TransformerEncoder and nn.TransformerDecoder that I might be missing?
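In case it’s relevant to the last question, this is roughly how I call generate() at inference time (the prompt string and the sampling settings are just an example):

# Example inference call (prompt and sampling settings are illustrative).
prompt = "Hello, how are you?"
input_ids = tokenizer(prompt, return_tensors='pt').input_ids.to(device)
length, generated_tensor, generated_text = model.generate(input_ids, max_length=32, temperature=0.8)
print(generated_text)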
Here is the link to the full source code:
https://www.kaggle.com/code/cutedeadu/transformer-chatbot-model-69e9a1
Any insights, suggestions, or pointers to what might be causing the stagnant loss would be greatly appreciated. Thank you in advance!