I am working on a Transformer-based chatbot model in PyTorch. The training loss decreases with each epoch, but the model's predictions do not improve and remain far from the target output. Here is a simplified version of my code:
import math
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.nn.utils.rnn import pad_sequence
from transformers import AutoTokenizer
torch.set_printoptions(sci_mode=False)

# The code below references `device`, so define it up front
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

tokenizer = AutoTokenizer.from_pretrained('t5-small')
tokenizer.add_special_tokens({'bos_token': '<s>'})  # T5 does not define a BOS token by default

class TransformerChatbot(nn.Module):
    def __init__(self, vocab_size, embed_size, hidden_size, num_layers, num_heads, max_length, dropout):
        super(TransformerChatbot, self).__init__()
        self.embed_size = embed_size
        self.max_length = max_length
        self.embedding = nn.Embedding(vocab_size, embed_size)
        self.positional_encoding = PositionalEncoding(embed_size, max_length, dropout)
        self.transformer_encoder = nn.TransformerEncoder(
            nn.TransformerEncoderLayer(embed_size, num_heads, hidden_size, dropout, batch_first=True),
            num_layers
        )
        self.transformer_decoder = nn.TransformerDecoder(
            nn.TransformerDecoderLayer(embed_size, num_heads, hidden_size, dropout, batch_first=True),
            num_layers
        )
        self.linear = nn.Linear(embed_size, vocab_size)
        self._reset_parameters()

    def _create_masks(self, input_seq, target_seq, pad_token_id):
        # Boolean key-padding masks (True = padding position) plus a causal mask for the decoder
        input_mask = (input_seq == pad_token_id)
        target_mask = (target_seq == pad_token_id)
        tgt_mask = nn.Transformer.generate_square_subsequent_mask(target_seq.size(1)).bool().to(device)
        return input_mask, target_mask, tgt_mask

    def forward(self, input_seq, target_seq, pad_token_id=None):
        if pad_token_id is None:
            pad_token_id = tokenizer.pad_token_id
        input_mask, target_mask, tgt_mask = self._create_masks(input_seq, target_seq, pad_token_id)
        input_embed = self.embedding(input_seq)
        input_embed = self.positional_encoding(input_embed)
        target_embed = self.embedding(target_seq)
        target_embed = self.positional_encoding(target_embed)
        encoder_output = self.transformer_encoder(input_embed, src_key_padding_mask=input_mask)
        decoder_output = self.transformer_decoder(
            target_embed,
            encoder_output,
            tgt_mask=tgt_mask,
            tgt_key_padding_mask=target_mask,
            memory_key_padding_mask=input_mask,
        )
        output_logits = self.linear(decoder_output)
        return output_logits

    def generate(self, input_ids, temperature=1.0, max_length=None):
        if max_length is None:
            max_length = self.max_length
        self.eval()
        # Start every sequence with the <s> token and sample autoregressively
        generated_seq = torch.ones(input_ids.size(0), 1).fill_(tokenizer.bos_token_id).int().to(device)
        finished_sequences = torch.zeros(input_ids.size(0), dtype=torch.bool, device=device)
        input_mask = (input_ids == tokenizer.pad_token_id)
        with torch.no_grad():
            input_embed = self.embedding(input_ids)
            input_embed = self.positional_encoding(input_embed)
            encoder_output = self.transformer_encoder(input_embed, src_key_padding_mask=input_mask)
            for _ in range(max_length - 1):
                tgt_mask = nn.Transformer.generate_square_subsequent_mask(generated_seq.size(1)).bool().to(device)
                output_embed = self.positional_encoding(self.embedding(generated_seq))
                decoder_output = self.transformer_decoder(
                    output_embed,
                    encoder_output,
                    tgt_mask=tgt_mask,
                    memory_key_padding_mask=input_mask
                )
                output_logits = self.linear(decoder_output)
                # Sample the next token from the temperature-scaled distribution at the last position
                next_token_logits = output_logits[:, -1, :] / temperature
                next_token_probs = F.softmax(next_token_logits, dim=-1)
                next_token = torch.multinomial(next_token_probs, 1)
                # Sequences that have already produced </s> keep emitting padding
                next_token = next_token.masked_fill(finished_sequences.view(-1, 1), tokenizer.pad_token_id)
                generated_seq = torch.cat([generated_seq, next_token], dim=1)
                finished_sequences = finished_sequences | (next_token.view(-1) == tokenizer.eos_token_id)
                if finished_sequences.all():
                    break
        generated_seq_list = [seq.tolist() for seq in generated_seq]
        generated_texts = [tokenizer.decode(seq, skip_special_tokens=True) for seq in generated_seq_list]
        return [len(seq) for seq in generated_seq_list], generated_seq, generated_texts

    def _reset_parameters(self):
        for p in self.parameters():
            if p.dim() > 1:
                nn.init.xavier_uniform_(p)

# Positional Encoding module
class PositionalEncoding(nn.Module):
    def __init__(self, embed_size, max_len=512, dropout=0.1):
        super(PositionalEncoding, self).__init__()
        self.dropout = nn.Dropout(p=dropout)
        pe = torch.zeros(max_len, embed_size)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, embed_size, 2).float() *
                             -(math.log(10000.0) / embed_size))
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        pe = pe.unsqueeze(0)
        self.register_buffer('pe', pe)

    def forward(self, x):
        # x: [batch_size, sequence_length, embed_size]
        x = x + self.pe[:, :x.size(1), :].to(device)
        return self.dropout(x)
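
For completeness, the model is created and called roughly like this. The hyperparameter values below are placeholders rather than the exact ones from the notebook:

# Placeholder hyperparameters, not the exact values from the notebook
model = TransformerChatbot(
    vocab_size=len(tokenizer),   # len() includes the added <s> token
    embed_size=256,
    hidden_size=512,
    num_layers=3,
    num_heads=8,
    max_length=128,
    dropout=0.1,
).to(device)

# Example inference call on a small padded batch
sample = tokenizer(["hello, how are you?"], return_tensors='pt', padding=True)
lengths, ids, texts = model.generate(sample['input_ids'].to(device))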
The loss function is cross-entropy, and training appears to proceed normally in the sense that the loss decreases over epochs. However, the model's predictions do not match the target sequences, and the output does not improve even after many epochs of training. I have tried adjusting the learning rate, using different initialization methods, and adding more layers, but none of these changes made a significant difference. The full source code is here: https://www.kaggle.com/code/cutedeadu/transformer-chatbot-model-69e9a1
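
For context, the loss is computed along these lines. This is a simplified sketch rather than the exact code from the notebook; it assumes the target sequences already start with <s> and end with </s>, that the decoder input is the target shifted right by one position, and that pad tokens are ignored in the loss:

# Simplified training step (assumptions noted above)
criterion = nn.CrossEntropyLoss(ignore_index=tokenizer.pad_token_id)
optimizer = optim.Adam(model.parameters(), lr=1e-4)

def train_step(input_ids, target_ids):
    model.train()
    optimizer.zero_grad()
    decoder_input = target_ids[:, :-1]        # decoder sees tokens up to position t-1
    labels = target_ids[:, 1:]                # loss is computed against the next token
    logits = model(input_ids, decoder_input)  # [batch, tgt_len - 1, vocab_size]
    loss = criterion(logits.reshape(-1, logits.size(-1)), labels.reshape(-1))
    loss.backward()
    optimizer.step()
    return loss.item()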
Any suggestions for debugging or improving the model’s performance would be greatly appreciated.