Why is my transformer model repeatedly predicting a single random word in every iteration?

Can anyone please explain why, when I use START and END tokens, the model only predicts the START and END tokens?

With teacher forcing the model predicts correct sentences, but it cannot do so at inference time, since it no longer has the target. With a reduced teacher-forcing ratio it predicts correct sentences only on the steps where teacher forcing is applied; otherwise it produces sentences like “START START START END END” when START and END tokens are used, or “car car car car” when they are not, and it does not learn anything.
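At inference time I generate autoregressively with greedy decoding, roughly like the sketch below (it mirrors the no-teacher-forcing branch of the training loop further down; start_index, end_index and max_len are illustrative names from my setup):

def greedy_decode(model, src, start_index, end_index, max_len):
    # sketch only: feed START, then repeatedly append the argmax token until END or max_len
    model.eval()
    with torch.no_grad():
        src = src.to(device)
        src_key_padding_mask = create_key_padding_mask(src, pad_index).to(device)
        encoder_output = model.encoder(src, src_key_padding_mask = src_key_padding_mask)

        # start with only the START token and grow the decoder input one token at a time
        decoder_input = torch.full((src.size(0), 1), start_index, dtype = torch.long, device = device)
        for _ in range(max_len):
            target_mask = look_ahead_mask(decoder_input.size(1)).to(device)
            logits = model.decoder(encoder_output, decoder_input, target_mask = target_mask)
            next_token = logits[:, -1].argmax(dim = -1, keepdim = True)
            decoder_input = torch.cat([decoder_input, next_token], dim = 1)
            if (next_token == end_index).all():
                break
    return decoder_input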

The loss, along with a sample prediction after each epoch, looks like this:

loss at epoch 10 is 5.1940178871154785

[1614, 136, 136, 136, 136, 136, 136, 136, 136, 136, 136, 136, 136, 136, 136]
sort a a a a a a a a a a a a a a
loss at epoch 11 is 5.237536907196045

[913, 393, 393, 393, 393, 393, 393, 393, 393, 393, 393, 393, 199, 393, 393, 393]
are to to to to to to to to to to to you to to to
loss at epoch 13 is 4.996100902557373

[1033, 740, 393, 740, 740, 740, 740, 740, 740, 740, 393, 740, 740, 136, 740, 740]
now the to the the the the the the the to the the a the the
loss at epoch 14 is 5.0865559577941895

[2212, 2212, 2212, 2212, 2212, 2212, 2212, 2212, 2212, 2212, 2212, 2212, 2212, 2212, 2212, 2212]
maybe maybe maybe maybe maybe maybe maybe maybe maybe maybe maybe maybe maybe maybe maybe maybe
loss at epoch 15 is 4.944348335266113

[1858, 393, 199, 393, 199, 393, 199, 393, 199, 199, 393, 199, 199, 199, 199]
he to you to you to you to you you to you you you you
loss at epoch 16 is 4.972095012664795

[1858, 1858, 1858, 1858, 1858, 1858, 1858, 1858, 1858, 1858, 1858, 1858, 1858, 1858, 1858, 1858, 1858]
he he he he he he he he he he he he he he he he he
loss at epoch 17 is 4.938692569732666

[2402, 2402, 2402, 2402, 2402, 2402, 2402, 2402, 2402, 2402, 2402, 2402, 2402, 2402, 2402, 136]
okay okay okay okay okay okay okay okay okay okay okay okay okay okay okay a
loss at epoch 18 is 4.899395942687988

The model is:
import math
import random

import torch
import torch.nn as nn
from torch.utils.data import DataLoader

class transformer(nn.Module):
    def __init__(self, embedding_dimension, number_of_heads, num_encoder_layers, num_decoder_layers, ff_dimension, vocab):
        super().__init__()
        
        self.embedding_dimension = embedding_dimension
        
        self.embedding = nn.Embedding(len(vocab),embedding_dimension)
        self.positional_encodings = positional_encoding(self.embedding_dimension,max_len)
        
        # batch_first = True, so all inputs/outputs are (batch, seq_len, d_model)
        self.transformer = nn.Transformer(d_model = embedding_dimension, nhead = number_of_heads, num_encoder_layers = num_encoder_layers, num_decoder_layers = num_decoder_layers, dim_feedforward = ff_dimension, batch_first = True)
        
        self.output = nn.Linear(self.embedding_dimension,len(vocab))
                
    def forward(self,src,target,src_key_padding_mask = None, tgt_key_padding_mask = None, tgt_mask = None, memory_mask = None):
        
        src = self.embedding(src) * math.sqrt(self.embedding_dimension)
        src = self.positional_encodings(src)
        
        target = self.embedding(target) * math.sqrt(self.embedding_dimension)
        target = self.positional_encodings(target)
        
        output = self.transformer(src,target, src_key_padding_mask = src_key_padding_mask, tgt_key_padding_mask = tgt_key_padding_mask, tgt_mask = tgt_mask, memory_mask = memory_mask)
        
        output = self.output(output)

        return output
        
    def encoder(self,src,src_key_padding_mask = None):
        src = self.embedding(src) * math.sqrt(self.embedding_dimension)
        src = self.positional_encodings(src)
        
        encoder_output = self.transformer.encoder(src,src_key_padding_mask = src_key_padding_mask)
        return encoder_output
    
    def decoder(self,encoder_output,target,memory_mask = None,target_key_padding_mask = None,target_mask = None):
        target = self.embedding(target) * math.sqrt(self.embedding_dimension)
        target = self.positional_encodings(target)
        
        decoder_output = self.transformer.decoder(target,encoder_output, memory_mask = memory_mask, tgt_key_padding_mask = target_key_padding_mask, tgt_mask = target_mask)
        
        output = self.output(decoder_output)
        
        return output
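The helpers referenced above (positional_encoding, create_key_padding_mask, look_ahead_mask) are not shown; they are roughly like the sketches below (standard sinusoidal positional encoding, a boolean padding mask, and an upper-triangular -inf look-ahead mask), and my exact versions may differ in small details:

class positional_encoding(nn.Module):
    # sketch: standard sinusoidal positional encoding, added to (batch, seq_len, d_model) inputs
    def __init__(self, embedding_dimension, max_len):
        super().__init__()
        position = torch.arange(0, max_len, dtype = torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, embedding_dimension, 2).float() * (-math.log(10000.0) / embedding_dimension))
        pe = torch.zeros(max_len, embedding_dimension)
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        self.register_buffer("pe", pe.unsqueeze(0))

    def forward(self, x):
        return x + self.pe[:, :x.size(1)]

def create_key_padding_mask(sequence, pad_index):
    # True where the token is padding, the convention nn.Transformer expects
    return sequence == pad_index

def look_ahead_mask(size):
    # float mask with -inf above the diagonal so position i cannot attend to positions > i
    return torch.triu(torch.full((size, size), float("-inf")), diagonal = 1)

The data pipeline and training loop are: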
    
datasett = custom_dataset(question,answer)
    
training_data = DataLoader(datasett, batch_size = 64, shuffle = True, collate_fn = add_padding)

model = transformer(128,8,3,3,512,vocab).to(device)

optimizer = torch.optim.Adam(model.parameters(), lr = 0.001) 

loss = nn.CrossEntropyLoss(ignore_index = pad_index)

losses = []

for epoch in range(200):
    for batch,(question,answer) in enumerate(training_data):
        
        optimizer.zero_grad()
        
        inputt  = question.to(device)
        target = answer.to(device)
        
        src_key_padding_mask = create_key_padding_mask(inputt,pad_index).to(device)
        target_key_padding_mask = create_key_padding_mask(target,pad_index).to(device)
        
        no_peek_mask = look_ahead_mask(target.size(1)).to(device)

        teacher_forcing_ratio = max(0.2, 1 - (epoch * 0.1))
        
        if random.random() + 1 < teacher_forcing_ratio:
            
            predicted = model.forward(inputt,target, src_key_padding_mask = src_key_padding_mask, tgt_key_padding_mask = target_key_padding_mask,tgt_mask = no_peek_mask)


        else:
            
            decoder_input = target[:,0].unsqueeze(1)
            
            encoder_output = model.encoder(inputt,src_key_padding_mask = src_key_padding_mask)
            
            predicted = []
            
            for predict in range(target.size(1)):
                decoder_output = model.decoder(encoder_output,decoder_input,target_key_padding_mask = create_key_padding_mask(decoder_input,pad_index).to(device),target_mask = look_ahead_mask(decoder_input.size(1))).to(device)
                
                predicted_word = torch.argmax(decoder_output[:,-1:],dim = -1)
                predicted.append(decoder_output[:,-1:])
                decoder_input = torch.cat([decoder_input,predicted_word],dim = 1).to(device)
                
            predicted = torch.stack(predicted,dim = 1).to(device)
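The tail of the batch loop is not shown above; it computes the cross-entropy on the flattened logits, backpropagates, and steps the optimizer, roughly like this sketch (the exact reshaping in my code may differ):

        # approximate end of the batch loop: flatten logits to (batch * seq_len, vocab_size)
        # and targets to (batch * seq_len,); padding is ignored via ignore_index in the criterion
        batch_loss = loss(predicted.reshape(-1, len(vocab)), target.reshape(-1))
        batch_loss.backward()
        optimizer.step()

    losses.append(batch_loss.item())
    print(f"loss at epoch {epoch} is {batch_loss.item()}")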

One more thing: after decreasing the learning rate to 0.00001 and keeping the batch size at 64, the model predicts different words for about 4-5 epochs and then stops learning; and during those 4-5 epochs it learns very slowly because the learning rate is so low.

The dataset consists of about 3000 short conversational sentences, like:
hi how are you doing
hi I’m doing good
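For context, the Dataset and collate function referenced earlier are roughly along these lines (an illustrative sketch; my real custom_dataset and add_padding differ in details such as where the START/END tokens are inserted):

from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import Dataset

class custom_dataset(Dataset):
    # sketch: each item is a (question, answer) pair already mapped to token indices
    def __init__(self, questions, answers):
        self.questions = questions
        self.answers = answers

    def __len__(self):
        return len(self.questions)

    def __getitem__(self, idx):
        return (torch.tensor(self.questions[idx], dtype = torch.long),
                torch.tensor(self.answers[idx], dtype = torch.long))

def add_padding(batch):
    # sketch: pad every question/answer in the batch to the longest sequence with pad_index
    questions, answers = zip(*batch)
    questions = pad_sequence(questions, batch_first = True, padding_value = pad_index)
    answers = pad_sequence(answers, batch_first = True, padding_value = pad_index)
    return questions, answers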