The model predicted almost-correct sentences at training time, but at test time it only predicts the <START> token.

The training loss looked like this:
loss at epoch 61 is 0.01835295371711254
[2, 3128, 4624, 4455, 3830, 3524, 4574, 3387, 2814, 2971, 4815, 3, 2729, 3551, 2882, 2893, 4405]
<START> me too lets forget about germs and focus on food

loss at epoch 62 is 0.010977335274219513
[2, 3876, 4126, 2729, 4669, 2776, 4098, 3, 2882, 4844, 3182, 2893, 4687, 3182, 4405]
<START> then you should watch the rerun <END> go is end

loss at epoch 63 is 0.010273752734065056
[2, 3324, 4844, 3987, 3, 4163, 4163, 3551, 4405, 4405, 4084, 2729, 4849, 2729, 4084, 4084, 2729]
<START> why is that <END> live live where can can guy

loss at epoch 64 is 0.010734939947724342
[2, 3987, 4844, 4319, 2776, 3466, 4844, 3083, 4139, 3, 4405, 4405, 4405, 4087, 4405, 4687]
<START> that is true the weather is constantly changing <END> can

The model code was:

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader,Dataset
import pandas
import numpy
import math
import os
import re
import matplotlib.pyplot as mat
from torch.cuda.amp import autocast,GradScaler
import random

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)

dataset = pandas.read_csv(r"Conversation.csv",usecols = ["question","answer"])

dataset_str = dataset.astype(str)
length = dataset_str.map(len)
dataset["length"] = length.sum(axis = 1)

dataset = dataset.sort_values(by = "length")

datasett = pandas.DataFrame()

datasett["question"] = dataset["question"].astype(str)
datasett["answer"] = dataset["answer"].astype(str)

def remove_punctuation(sentence):
    sentence = re.sub(r"[^\w\s]","",sentence)
    return sentence

datasett["question"] = datasett["question"].map(remove_punctuation)
datasett["answer"] = datasett["answer"].map(remove_punctuation)

vocab = ["<PAD>", "<UNK>", "<START>", "<END>"] + list(set(word for sentence in datasett["question"] for word in sentence.split())) + list(set(word for sentence in datasett["answer"] for word in sentence.split()))

word_to_index = {word:index for index,word in enumerate(vocab)}
index_to_word = {index:word for word,index in word_to_index.items()}
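
# Quick sanity check of the special-token indices (illustrative; assumes the
# four special tokens head the vocab as above). The token IDs printed in the
# training logs are consistent with this layout: every sequence starts with 2
# (<START>) and 3 marks <END>.
print(word_to_index["<START>"], word_to_index["<END>"])   # 2 3
print(word_to_index["<PAD>"], word_to_index["<UNK>"])     # 0 1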

training_data = []

for i in range(0,len(dataset)):
    data = datasett["question"].iloc[i]
    target = datasett["answer"].iloc[i]
    training_data.append((data,target))

class custom_dataset(Dataset):
    def __init__(self,training_data,word_to_index,index_to_word):
        self.question,self.answer = zip(*training_data)
        self.word_to_index = word_to_index
        self.index_to_word = index_to_word

    def __len__(self):
        return len(self.question)

    def __getitem__(self,index):
        question = re.sub(r"[^\w\s]", "" ,self.question[index]).lower()
        answer = re.sub(r"[^\w\s]", "" ,self.answer[index]).lower()

        sentence = torch.tensor([self.word_to_index.get("<START>")] + [self.word_to_index.get(word,0) for word in (question.split())] + [self.word_to_index.get("<END>")])

        target_sentence = torch.tensor([self.word_to_index.get("<START>")] + [self.word_to_index.get(word,0) for word in (answer.split())] + [self.word_to_index.get("<END>")])

        return (sentence, target_sentence)

def collate_batch(batch):
    # pad every question and answer in the batch with zeros (index 0 = <PAD>)
    # up to the length of the longest sequence in the batch
    sentences,target = zip(*batch)

    padded_sentences = []
    padded_target_sentences = []

    max_sentence = len(max(sentences,key = len))
    max_target = len(max(target,key = len))

    max_length = max(max_sentence,max_target)

    for sentence in sentences:
        padding = torch.zeros(max_length - len(sentence))
        sentence = torch.cat([sentence,padding])
        padded_sentences.append(sentence)

    for target_sentence in target:
        padding = torch.zeros(max_length - len(target_sentence))
        target_sentence = torch.cat([target_sentence,padding])
        padded_target_sentences.append(target_sentence)

    sentences = torch.stack(padded_sentences).long()
    target = torch.stack(padded_target_sentences).long()

    return (sentences,target)
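
# Illustrative check of collate_batch with made-up token IDs: sequences of
# different lengths come back zero-padded to the longest one in the batch.
example = [(torch.tensor([2, 7, 3]), torch.tensor([2, 5, 6, 3])),
           (torch.tensor([2, 3]), torch.tensor([2, 8, 3]))]
s, t = collate_batch(example)
print(s.shape, t.shape)   # torch.Size([2, 4]) torch.Size([2, 4])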

dataset = custom_dataset(training_data,word_to_index,index_to_word)

training_data = DataLoader(dataset, batch_size = 32, shuffle = True, collate_fn = collate_batch)

class Transformer(nn.Module):
    def __init__(self):
        super().__init__()

        self.embeddings = nn.Embedding(len(vocab),128)
        self.transformer = nn.Transformer(128,4,2,2,512,batch_first = True,layer_norm_eps = 0.0001)
        self.out = nn.Linear(128,len(vocab))

    def generate_mask(self,sentence_length):
        # lower-triangular "no peek" mask: visible positions stay 1.0,
        # future positions become -1e9
        mask = torch.tril(torch.ones(sentence_length,sentence_length),diagonal = 0)
        return mask.masked_fill(mask == 0,-1e9)

    def create_key_padding_mask(self,sequence):
        # float mask: 1.0 over real tokens, -1e9 over padding (index 0)
        mask = (sequence != 0)
        mask = mask.float()
        return mask.masked_fill(mask == 0,-1e9)

    def forward(self,sentence,target,src_key_padding_mask = None, target_key_padding_mask = None, no_peek_mask = None):
        sentence = self.embeddings(sentence.long())
        target = self.embeddings(target.long())

        output = self.transformer(sentence,target,src_key_padding_mask = src_key_padding_mask,tgt_key_padding_mask = target_key_padding_mask, tgt_mask = no_peek_mask)

        output = self.out(output)

        return output

Model = Transformer().to(device)
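
# Illustrative: what the mask helper returns, here for a length-3 target.
# nn.Transformer's documented convention for a float mask is additive
# (0.0 = attend, -inf = block), so note that the visible positions below come
# out as 1.0, not 0.0; the float key-padding masks follow the same pattern.
print(Model.generate_mask(3))
# tensor([[ 1e+00, -1e+09, -1e+09],
#         [ 1e+00,  1e+00, -1e+09],
#         [ 1e+00,  1e+00,  1e+00]])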

pad_index = dataset.word_to_index.get("<PAD>")

loss = nn.CrossEntropyLoss(ignore_index = pad_index)

losses = []

optimizer = torch.optim.Adam(Model.parameters(), lr = 0.0001)
scaler = GradScaler()

for epoch in range(100):
    Model.train()
    for batch,(sentence,target) in enumerate(training_data):

        sentence = sentence.to(device)
        target = target.to(device)

        torch.autograd.set_detect_anomaly(True)
        optimizer.zero_grad()

        batch_size = sentence.size(0)
        sentence_length = sentence.size(1)
        target_length = target.size(1)

        end_token_index = dataset.word_to_index.get("<END>")

        with autocast():
            src_key_padding_mask = Model.create_key_padding_mask(sentence).to(device)
            target_key_padding_mask = Model.create_key_padding_mask(target).to(device)
            target_mask = Model.generate_mask(target_length).to(device)

            predicted = Model.forward(sentence,target,src_key_padding_mask,target_key_padding_mask,no_peek_mask = target_mask)

            predictedd = predicted

            end_mask = (target != pad_index)
            end_mask = end_mask.unsqueeze(-1).float()

            predicted = predicted * end_mask

            predicted = predicted.view(-1,predicted.size(-1))
            target = target.view(-1)

            losss = loss(predicted,target)

            torch.nn.utils.clip_grad_norm_(Model.parameters(),max_norm = 0.5)

            scaler.scale(losss).backward()
            scaler.step(optimizer)
            scaler.update()

        if(losss.item() < 0.003):
            break

    losses.append(losss.detach().item())
    print(f"loss at epoch {epoch} is {losss}")

    Model.eval()
    with torch.no_grad():
        predictedd = torch.argmax(predictedd,dim = -1).cpu().tolist()
        for sentence in predictedd:
            i = 0
            print(sentence)
            for word in sentence:
                print(dataset.index_to_word.get(word), end = " ")
                i += 1
                if(i>10):
                    break
            print("\n")
            break

    if(losss.item() < 0.1):
        break
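
# For reference, the torch.cuda.amp docs order gradient clipping as: backward
# on the scaled loss, unscale_, clip, then step/update (a sketch of the
# documented pattern, not what the run above did):
#     scaler.scale(losss).backward()
#     scaler.unscale_(optimizer)
#     torch.nn.utils.clip_grad_norm_(Model.parameters(), max_norm = 0.5)
#     scaler.step(optimizer)
#     scaler.update()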

mat.plot(range(len(losses)),losses)
mat.xlabel("epochs")
mat.ylabel("loss")
mat.savefig("loss.jpg")

torch.save({"model_state_dict" : Model.state_dict(),
            "word_to_index" : dataset.word_to_index,
            "index_to_word" : dataset.index_to_word,
            },"transformer.pt")

torch.cuda.empty_cache()
loaded = torch.load("transformer.pt")

with torch.no_grad():
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(device)

    predicted = "<START>"
    last_word = ["<START>"]

    Model.load_state_dict(loaded["model_state_dict"])
    Model = Model.to(device)

    for parameter in Model.parameters():
        parameter.requires_grad = False

    word_to_index = loaded["word_to_index"]
    index_to_word = loaded["index_to_word"]

    Model.eval()

    torch.cuda.empty_cache()

    i = 0
    while (i != 10):
        text = "hi, how are you doing?"
        text = re.sub(r"[^\w\s]","",text).lower()

        target = predicted

        target_mask = Model.generate_mask(len(target.split())).to(device)

        text = torch.tensor([word_to_index.get(word, word_to_index.get("<UNK>")) for word in text.split()]).to(device)
        output = torch.tensor([word_to_index.get(word) for word in target.split()]).to(device)

        text = text.unsqueeze(0).long().to(device)
        output = output.unsqueeze(0).long().to(device)

        predictedd = torch.argmax(Model.forward(text,target = output,no_peek_mask = target_mask),dim = -1)

        if (predictedd.size(1) > 1):
            predictedd = predictedd[0,-1].tolist()
            predicted_list = index_to_word.get(predictedd)
            last_word.append(predicted_list)
            predicted = " ".join(last_word)
            print(predicted)
            torch.cuda.empty_cache()
            i += 1
        else:
            predictedd = predictedd.squeeze().tolist()
            print(predictedd)
            last_word.append(index_to_word.get(predictedd))
            predicted = " ".join(last_word)
            print(predicted)
            i += 1
            torch.cuda.empty_cache()
        
print(predicted)

torch.cuda.empty_cache()
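
In case it helps anyone reproduce this, below is a minimal teacher-forced sanity check (a sketch using the objects defined above; nothing here is from the original run). If the argmax output still looks right under teacher forcing but the word-by-word loop above degenerates to <START>, the gap is in the decoding loop rather than the weights.

Model.eval()
with torch.no_grad():
    q, a = dataset[0]                                   # one (question, answer) pair, already tokenised
    q = q.unsqueeze(0).to(device)
    a = a.unsqueeze(0).to(device)
    tgt_mask = Model.generate_mask(a.size(1)).to(device)
    out = Model.forward(q, a, no_peek_mask = tgt_mask)  # teacher forcing: full target fed in
    ids = torch.argmax(out, dim = -1)[0].cpu().tolist()
    print(" ".join(index_to_word.get(i, "<UNK>") for i in ids))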