Transformer decoder not learning

I was trying to use a nn.TransformerDecoder to obtain text generation results. But the model remains not trained (loss not decreasing, produce only padding tokens). The code is as below:

import torch
import torch.nn as nn
import math

import math
class PositionalEncoding(nn.Module):
    def __init__(self, d_model, max_len=5000):
        super(PositionalEncoding, self).__init__()

        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        pe = pe.unsqueeze(0).transpose(0, 1)
        self.register_buffer('pe', pe)

    def forward(self, x):
        return x +[:x.size(0), :]
class TransformerDecoder(nn.Module):
    def __init__(self, vocabs, d_model, nhead, num_layers,dim_feedforward=2048, dropout=0.1,device = "cpu"):
        super(TransformerDecoder, self).__init__()
        self.d_model = d_model
        self.vocabs = vocabs
        self.vocab_size = len(vocabs)
        self.device = device
        self.embedding = nn.Embedding(self.vocab_size, self.d_model)
        self.pos_encoder = PositionalEncoding(d_model)
        self.word_embedding = nn.Embedding( self.vocab_size , self.d_model)
        self.transformer_decoder = nn.TransformerDecoder(
            nn.TransformerDecoderLayer(d_model, nhead, dim_feedforward, dropout),
        self.fc_out = nn.Linear(d_model, self.vocab_size)
        self.start_token = self.vocabs.word2idx['<start>'] # Default
        self.end_token = self.vocabs.word2idx['<end>']

    memory: input embedding : (batch_size, embedding_size)
    tgt: tgt_size: (batch_size, max_len)
    def forward(self, memory, tgt):
        tgt = tgt.permute(1,0)
        memory = memory.permute(1,0)

        seq_len = tgt.shape[0]
        mem_len = memory.shape[0]
        memory = memory.unsqueeze(0).permute(0,2,1)

        tgt_mask = nn.Transformer().generate_square_subsequent_mask(seq_len).to(self.device)
        tgt = self.embedding(tgt) * math.sqrt(self.d_model)
        tgt = self.pos_encoder(tgt)

        output = self.transformer_decoder(tgt, memory, tgt_mask = tgt_mask)

        output = self.fc_out(output)

        return output

    memory: input embedding : (batch_size, embedding_size)
    def predict(self,memory, max_len = 90):
        batch_size = memory.shape[0]
        outputs_tensor = torch.zeros((batch_size, max_len, self.vocab_size)).to(self.device)

        generated_words = torch.LongTensor([self.start_token]*batch_size).unsqueeze(1).to(self.device)
        outputs_tensor[:,0,self.start_token] = 1
        for i in range(max_len-1):
            with torch.no_grad():
                output = self.forward(memory, generated_words)
            outputs_tensor[:,i+1,:] = output[-1,:,:]
            next_word = torch.argmax(output[-1,:,:],dim = 1).unsqueeze(1)
            generated_words =, next_word), dim=1)

            if (next_word == self.end_token).all():
        return generated_words,outputs_tensor

Would there be anything wrong with the forward or predict function? I saw some post that learning rate too high might disturb training. I set it to 0.0001 but still receive no loss drop.
Thanks in advance for all the help. Would really really appreciate it.