Why is my transformer model behaving like this?

loss at epoch 10 is 5.1940178871154785

[1614, 136, 136, 136, 136, 136, 136, 136, 136, 136, 136, 136, 136, 136, 136]
sort a a a a a a a a a a a a a a
loss at epoch 11 is 5.237536907196045

[2160, 2160, 2160, 2160, 2160, 2160, 2160, 2160, 2160, 2160, 2160, 2160, 2160, 2160, 2160, 2160, 2160]
i i i i i i i i i i i i i i i i i
loss at epoch 12 is 5.147286415100098

[913, 393, 393, 393, 393, 393, 393, 393, 393, 393, 393, 393, 199, 393, 393, 393]
are to to to to to to to to to to to you to to to
loss at epoch 13 is 4.996100902557373

[1033, 740, 393, 740, 740, 740, 740, 740, 740, 740, 393, 740, 740, 136, 740, 740]
now the to the the the the the the the to the the a the the
loss at epoch 14 is 5.0865559577941895

[2212, 2212, 2212, 2212, 2212, 2212, 2212, 2212, 2212, 2212, 2212, 2212, 2212, 2212, 2212, 2212]
maybe maybe maybe maybe maybe maybe maybe maybe maybe maybe maybe maybe maybe maybe maybe maybe
loss at epoch 15 is 4.944348335266113

[1858, 393, 199, 393, 199, 393, 199, 393, 199, 199, 393, 199, 199, 199, 199]
he to you to you to you to you you to you you you you
loss at epoch 16 is 4.972095012664795

[1858, 1858, 1858, 1858, 1858, 1858, 1858, 1858, 1858, 1858, 1858, 1858, 1858, 1858, 1858, 1858, 1858]
he he he he he he he he he he he he he he he he he
loss at epoch 17 is 4.938692569732666

[2402, 2402, 2402, 2402, 2402, 2402, 2402, 2402, 2402, 2402, 2402, 2402, 2402, 2402, 2402, 136]
okay okay okay okay okay okay okay okay okay okay okay okay okay okay okay a
loss at epoch 18 is 4.899395942687988

The model is:

```python
import torch
import torch.nn as nn
from torch.utils.data import Dataset,DataLoader
import re
import pandas
import math
import random
import matplotlib.pyplot as mat

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)

dataset = pandas.read_csv(r"Conversation.csv", usecols=["question", "answer"])

question = dataset["question"].values
answer = dataset["answer"].values

def remove_punctuation(sentence):
    sentence = re.sub(r"[^\w\s]", "", str(sentence)).lower()
    sentence = sentence.split()
    return sentence

question = list(map(remove_punctuation, question))
answer = list(map(remove_punctuation, answer))

vocab = ["<PAD>", "<START>", "<END>", "<UNK>"] + list(set([word for sentence in question for word in sentence] + [word for sentence in answer for word in sentence]))

word_to_index = {word : index for index,word in enumerate(vocab)}
index_to_word = {index : word for word,index in word_to_index.items()}

pad_index = word_to_index.get("<PAD>", 0)
unknown_index = word_to_index.get("<UNK>", 3)
end_index = word_to_index.get("<END>", 2)

class custom_dataset(Dataset):
    def __init__(self, question, answer):
        self.question = question
        self.answer = answer

    def __len__(self):
        return len(self.question)

    def __getitem__(self, index):
        question = [word_to_index.get(word) for word in self.question[index]]
        answer = [word_to_index.get(word) for word in self.answer[index]]
        return torch.tensor(question, dtype=torch.long).to(device), torch.tensor(answer, dtype=torch.long).to(device)

def add_padding(batch):
    questions, answers = zip(*batch)

    max_question_length = max(len(question) for question in questions)
    max_answer_length = max(len(answer) for answer in answers)

    max_len = max(max_question_length, max_answer_length)

    padded_questions = []
    padded_answers = []

    for question in questions:
        extend_till = max_len - len(question)
        padded_question = torch.cat([question, torch.tensor([pad_index] * extend_till).to(device)])
        padded_questions.append(padded_question)

    for answer in answers:
        extend_till = max_len - len(answer)
        padded_answer = torch.cat([answer, torch.tensor([pad_index] * extend_till).to(device)])
        padded_answers.append(padded_answer)

    questions = torch.stack(padded_questions).long().to(device)
    answers = torch.stack(padded_answers).long().to(device)

    return questions, answers

def create_key_padding_mask(tensor, pad_index):
    mask = (tensor == pad_index).to(device)
    return mask

def look_ahead_mask(sentence_length):
    mask = torch.triu(torch.ones(sentence_length, sentence_length) * float("-inf"), diagonal=1).bool().to(device)
    return mask

max_len = max(dataset["question"].apply(len).max(), dataset["answer"].apply(len).max())

class positional_encoding(nn.Module):
    def __init__(self, embedding_dimension, max_len):
        super().__init__()

        position = torch.arange(0, max_len).unsqueeze(1).to(device)

        self.positional_encodings = torch.empty(max_len, embedding_dimension).to(device)

        div_term = torch.exp(torch.arange(0, embedding_dimension, 2) * (-math.log(10000) / embedding_dimension)).to(device)
        div_term = div_term.unsqueeze(0)

        self.positional_encodings[:, 0::2] = torch.sin(position * div_term).to(device)
        self.positional_encodings[:, 1::2] = torch.cos(position * div_term).to(device)

        self.register_buffer("positional_encoding", self.positional_encodings)

    def forward(self, x):
        x = x + self.positional_encodings[:x.size(1), :].to(device)
        return x

class transformer(nn.Module):
    def __init__(self, embedding_dimension, number_of_heads, num_encoder_layers, num_decoder_layers, ff_dimension, vocab):
        super().__init__()

        self.embedding_dimension = embedding_dimension

        self.embedding = nn.Embedding(len(vocab), embedding_dimension)
        self.positional_encodings = positional_encoding(self.embedding_dimension, max_len)

        self.transformer = nn.Transformer(d_model=embedding_dimension, nhead=number_of_heads, num_encoder_layers=num_encoder_layers, num_decoder_layers=num_decoder_layers, dim_feedforward=ff_dimension, batch_first=True)

        self.output = nn.Linear(self.embedding_dimension, len(vocab))
            
    def forward(self, src, target, src_key_padding_mask=None, tgt_key_padding_mask=None, tgt_mask=None, memory_mask=None):
        src = self.embedding(src) * math.sqrt(self.embedding_dimension)
        src = self.positional_encodings(src)

        target = self.embedding(target) * math.sqrt(self.embedding_dimension)
        target = self.positional_encodings(target)

        output = self.transformer(src, target, src_key_padding_mask=src_key_padding_mask, tgt_key_padding_mask=tgt_key_padding_mask, tgt_mask=tgt_mask, memory_mask=memory_mask)

        output = self.output(output)
        return output

    def encoder(self, src, src_key_padding_mask=None):
        src = self.embedding(src) * math.sqrt(self.embedding_dimension)
        src = self.positional_encodings(src)

        encoder_output = self.transformer.encoder(src, src_key_padding_mask=src_key_padding_mask)
        return encoder_output

    def decoder(self, encoder_output, target, memory_mask=None, target_key_padding_mask=None, target_mask=None):
        target = self.embedding(target) * math.sqrt(self.embedding_dimension)
        target = self.positional_encodings(target)

        decoder_output = self.transformer.decoder(target, encoder_output, memory_mask=memory_mask, tgt_key_padding_mask=target_key_padding_mask, tgt_mask=target_mask)

        output = self.output(decoder_output)
        return output

datasett = custom_dataset(question,answer)

training_data = DataLoader(datasett, batch_size = 64, shuffle = True, collate_fn = add_padding)

model = transformer(128,8,3,3,512,vocab).to(device)

optimizer = torch.optim.Adam(model.parameters(), lr = 0.001)

loss = nn.CrossEntropyLoss(ignore_index = pad_index)

losses = []

for epoch in range(200):
    for batch, (question, answer) in enumerate(training_data):

        optimizer.zero_grad()

        inputt = question.to(device)
        target = answer.to(device)

        src_key_padding_mask = create_key_padding_mask(inputt, pad_index).to(device)
        target_key_padding_mask = create_key_padding_mask(target, pad_index).to(device)

        no_peek_mask = look_ahead_mask(target.size(1)).to(device)

        teacher_forcing_ratio = max(0.2, 1 - (epoch * 0.1))

        # note: random.random() + 1 is always >= 1, while the ratio is capped at 1,
        # so this teacher-forcing branch is never entered as written
        if random.random() + 1 < teacher_forcing_ratio:
        
            predicted = model.forward(inputt, target, src_key_padding_mask=src_key_padding_mask, tgt_key_padding_mask=target_key_padding_mask, tgt_mask=no_peek_mask)

            predictedd = torch.argmax(predicted, dim=-1)

            if batch == 25:
                for sentence in predictedd:
                    sentence = sentence.cpu().tolist()
                    print(sentence)
                    for word in sentence:
                        print(index_to_word.get(word), end=" ")
                    break
                print()
                break  # ends the epoch early once the sample is printed

        else:
            # autoregressive decoding seeded with the first target token
            decoder_input = target[:, 0].unsqueeze(1)

            encoder_output = model.encoder(inputt, src_key_padding_mask=src_key_padding_mask)

            predicted = []

            for predict in range(target.size(1)):
                decoder_output = model.decoder(encoder_output, decoder_input, target_key_padding_mask=create_key_padding_mask(decoder_input, pad_index).to(device), target_mask=look_ahead_mask(decoder_input.size(1))).to(device)

                predicted_word = torch.argmax(decoder_output[:, -1:], dim=-1)
                predicted.append(decoder_output[:, -1:])
                decoder_input = torch.cat([decoder_input, predicted_word], dim=1).to(device)

            # stack gives (batch, seq, 1, vocab); the view below flattens it to (batch*seq, vocab)
            predicted = torch.stack(predicted, dim=1).to(device)

            if batch == 25:
                for sentence in decoder_input:
                    sentence = sentence.cpu().tolist()
                    print(sentence)
                    for word in sentence:
                        print(index_to_word.get(word), end=" ")
                    break
                print()
                break
        
        predicted = predicted.view(-1, predicted.size(-1))
        target = target.view(-1)

        losss = loss(predicted, target)
        losss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        optimizer.step()

    if epoch % 4 == 0:
        losses.append(losss.detach().item())

    print(f"loss at epoch {epoch} is {losss.detach().item()}")
    print()

mat.plot(range(len(losses)), losses)
mat.xlabel("epochs")
mat.ylabel("loss")
mat.savefig("loss.png")

torch.save({"model_state_dict": model.state_dict(),
            "word_to_index": word_to_index,
            "index_to_word": index_to_word}, "transformer.pt")

with torch.no_grad():
    model.eval()
    loaded = torch.load("transformer.pt")

    model.load_state_dict(loaded["model_state_dict"])
    word_to_index = loaded["word_to_index"]
    index_to_word = loaded["index_to_word"]

    question = "hello, may i speak to alice please?"
    answer = ["<START>"]  # keep the literal token; remove_punctuation would strip the angle brackets

    question = remove_punctuation(question)

    inputt = torch.tensor([word_to_index.get(word, unknown_index) for word in question]).unsqueeze(0).to(device)
    target = torch.tensor([word_to_index.get(word, unknown_index) for word in answer]).unsqueeze(0).to(device)

    src_key_padding_mask = create_key_padding_mask(inputt, pad_index).to(device)

    encoder_output = model.encoder(inputt, src_key_padding_mask)

    for predict in range(10):
        target_key_padding_mask = create_key_padding_mask(target, pad_index).to(device)
        no_peek_mask = look_ahead_mask(target.size(1)).to(device)

        # pass the encoder output as memory, not the raw input ids
        output = model.decoder(encoder_output, target, target_key_padding_mask=target_key_padding_mask, target_mask=no_peek_mask)

        predicted = output[:, -1].to(device)
        predicted_word = torch.argmax(predicted, dim=1, keepdim=True)  # keepdim so the cat below works

        target = torch.cat([target, predicted_word], dim=1).to(device)

        print(index_to_word.get(predicted_word.item()), end=" ")

        if predicted_word.item() == end_index:
            break
```

On using <START> and <END> tokens, it only predicts the <END> token.
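
For reference, the standard teacher-forcing setup shifts the target: the decoder is fed the target shifted right and the loss is computed against the target shifted left, so position t never conditions on the very token it has to predict. A minimal sketch of such a training step, not the poster's code: it reuses the `model`, `loss`, and mask helpers above and assumes the padded `answer` batches already include <START> and <END> tokens:

```python
# hypothetical corrected training step, assuming <START>/<END> already appear in `answer`
# decoder input:  <START> w1 w2 ... wn
# loss target:    w1 w2 ... wn <END>
decoder_input = answer[:, :-1]
gold = answer[:, 1:]

logits = model(question, decoder_input,
               src_key_padding_mask=create_key_padding_mask(question, pad_index),
               tgt_key_padding_mask=create_key_padding_mask(decoder_input, pad_index),
               tgt_mask=look_ahead_mask(decoder_input.size(1)))

# reshape rather than view, since the slice above is not contiguous
losss = loss(logits.view(-1, logits.size(-1)), gold.reshape(-1))
```

Without this shift, the look-ahead mask still exposes the diagonal, so each position can simply copy the token it was given; that shortcut lowers the loss without learning to generate, which can show up as exactly this kind of repeated-token collapse.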

I have a working example for a machine translation use case. Maybe this helps.
