Transformer Decoder always predicts the same token with CLIP embedding

I’m trying to train a Transformer decoder to generate textual captions from a CLIP embedding autoregressively. I followed this official PyTorch tutorial as the base for my project. During training the loss decreases, but at inference the model always predicts the same token for every position in the sequence. Debugging further, I found that it does the same thing during training as well. It seems that either the positional encoding or the tgt_mask is not working as expected.
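As a minimal standalone sanity check on those two suspects (toy sizes, nothing taken from my actual model), the mask and the sinusoidal encoding look correctly constructed in isolation:

import math
import torch

# Reproduce the tutorial-style causal mask at a toy size and inspect it:
# zeros on and below the diagonal (attention allowed), -inf above (blocked).
sz = 5
mask = (torch.triu(torch.ones(sz, sz)) == 1).transpose(0, 1)
mask = mask.float().masked_fill(mask == 0, float('-inf')).masked_fill(mask == 1, 0.0)
print(mask)

# Reproduce the sinusoidal positional encoding and verify that adding it to a
# constant sequence makes the positions distinguishable.
d_model, seq_len = 8, 4
pe = torch.zeros(seq_len, d_model)
position = torch.arange(0, seq_len, dtype=torch.float).unsqueeze(1)
div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
pe[:, 0::2] = torch.sin(position * div_term)
pe[:, 1::2] = torch.cos(position * div_term)
x = torch.ones(seq_len, 1, d_model)      # identical "token embeddings" at every position
x_pe = x + pe.unsqueeze(1)               # (seq_len, batch=1, d_model)
print(torch.allclose(x_pe[0], x_pe[1]))  # prints False: positions differ after the PE

So the building blocks themselves seem fine on their own; the problem must be in how they are wired into the model below.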

My Transformer Decoder:

import torch
import torch.nn as nn


class TextDecoder_TRANSFORMER(nn.Module):
    def __init__(self, latent_dim=512, ff_size=1024, num_layers=4, num_heads=4, dropout=0.1, activation="gelu", **kargs):
        super().__init__()
        
        self.latent_dim = latent_dim
        self.ff_size = ff_size
        self.num_layers = num_layers
        self.num_heads = num_heads
        self.dropout = dropout
        self.activation = activation
                
        self.sequence_pos_encoder = PositionalEncoding(self.latent_dim, self.dropout)
        seqTransDecoderLayer = nn.TransformerDecoderLayer(d_model=self.latent_dim,
                                                          nhead=self.num_heads,
                                                          dim_feedforward=self.ff_size,
                                                          dropout=self.dropout,
                                                          activation=activation)
        self.seqTransDecoder = nn.TransformerDecoder(seqTransDecoderLayer,
                                                     num_layers=self.num_layers)
        
        # 49408 is the size of CLIP's BPE vocabulary
        self.finallayer = nn.Linear(self.latent_dim, 49408)
        
    def decode(self, tgt, memory, tgt_mask):
        # Inference path: add positional encoding to the target embeddings, then decode
        return self.seqTransDecoder(self.sequence_pos_encoder(tgt), memory, tgt_mask=tgt_mask)
    
    def forward(self, batch):

        # z is the CLIP embedding
        z, mask, padding_mask = batch["z"], batch["text_mask"], batch["text_padding_mask"]
        
        # Token embeddings of the target captions (teacher-forcing input)
        captions = batch["clip_text_embedding"]
        captions = self.sequence_pos_encoder(captions)

        output = self.seqTransDecoder(tgt=captions, 
                                      tgt_mask=mask,
                                      tgt_key_padding_mask=padding_mask,
                                      memory=z)
        
        output = self.finallayer(output)
        batch["output_caption"] = output
        return batch
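The decoder keeps batch_first at its default (False), so everything is sequence-first. A quick shape smoke test with dummy tensors runs through; the memory shape (1, batch, 512) is only my assumption of how the single CLIP embedding is passed in, not necessarily what my pipeline produces:

import torch

dec = TextDecoder_TRANSFORMER()
seq_len, bs = 10, 2
dummy = {
    "z": torch.randn(1, bs, 512),                          # CLIP embedding used as memory (assumed shape)
    "clip_text_embedding": torch.randn(seq_len, bs, 512),  # caption token embeddings, sequence-first
    "text_mask": torch.zeros(seq_len, seq_len),            # additive mask, all positions allowed here
    "text_padding_mask": torch.zeros(bs, seq_len, dtype=torch.bool),
}
out = dec(dummy)
print(out["output_caption"].shape)  # torch.Size([10, 2, 49408])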

The Positional Encoding Used:

import numpy as np
import torch
import torch.nn as nn


class PositionalEncoding(nn.Module):
    def __init__(self, d_model, dropout=0.1, max_len=5000):
        super(PositionalEncoding, self).__init__()
        self.dropout = nn.Dropout(p=dropout)

        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-np.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        pe = pe.unsqueeze(0).transpose(0, 1)  # (max_len, 1, d_model): sequence-first layout
        
        self.register_buffer('pe', pe)

    def forward(self, x):
        # x: (seq_len, batch, d_model); add the encoding for the first seq_len positions
        x = x + self.pe[:x.shape[0], :]
        return self.dropout(x)

And the code used to encode the textual information fed to the model:

    # Extracted from: https://pytorch.org/tutorials/beginner/translation_transformer.html
    def generate_square_subsequent_mask(self, sz):
        mask = (torch.triu(torch.ones((sz, sz), device=self.device)) == 1).transpose(0, 1)
        mask = mask.float().masked_fill(mask == 0, float('-inf')).masked_fill(mask == 1, float(0.0))
        return mask
    
    def create_mask(self, tokens):
        tk_seq_len = tokens.shape[0]
        tk_mask = self.generate_square_subsequent_mask(tk_seq_len)
        # CLIP pads with token id 0; tokens are (seq_len, batch), so transpose the mask to (batch, seq_len)
        tk_padding_mask = (tokens == 0).transpose(0, 1)
        return tk_mask, tk_padding_mask

    def encode_text(self, batch):
        captions = batch["clip_text"]

        captions_tensor = []
        for caption in captions:
            caption = clip.tokenize(caption).to(self.device)
            captions_tensor.append(caption)
        
        captions_tensor = torch.cat(captions_tensor, dim=0).to(self.device)
        captions_tensor = captions_tensor.permute(1, 0)  # (batch, seq_len) -> (seq_len, batch), sequence-first
        batch["text_tokens"] = captions_tensor

        # Build masks for the decoder input (all tokens except the last one)
        mask, padding_mask = self.create_mask(captions_tensor[:-1, :])
        batch["text_mask"] = mask
        batch["text_padding_mask"] = padding_mask

        caption_embedding = self.clip_model.token_embedding(captions_tensor[:-1, :]).type(self.clip_model.dtype)
        caption_embedding = caption_embedding.to(self.device)
        batch["clip_text_embedding"] = caption_embedding

        return batch
    

    # function to generate output sequence using greedy algorithm
    def greedy_decode(self,text, memory, max_len):
        # The tokenizer is the SimpleTokenizer from the CLIP repository
        START_IDX = self.tokenizer.encoder['<|startoftext|>']
        EOS_IDX = self.tokenizer.encoder['<|endoftext|>']

        ys = torch.ones(1, 1).fill_(START_IDX).type(torch.long).to(self.device)
        for i in range(max_len-1):
            memory = memory.to(self.device)
            tgt_mask = (self.generate_square_subsequent_mask(ys.size(0))
                        .type(torch.bool)).to(self.device)

            token_embs = self.clip_model.token_embedding(ys)
            
            out = self.textDecoder.decode(token_embs, memory, tgt_mask)
            out = out.transpose(0, 1)  # (seq, batch, dim) -> (batch, seq, dim)
            prob = self.textDecoder.finallayer(out[0, :, :])  # vocab logits for every position
            _, next_word = torch.max(prob, dim=1)
            next_word = next_word[-1].item()  # greedy choice at the last position
            
            ys = torch.cat([ys, torch.ones(1, 1).type(torch.long).fill_(next_word).to(self.device)], dim=0)

            if next_word == EOS_IDX:
                break
        
        return ys
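One thing I wanted to rule out is whether the float mask (0.0 = attend, -inf = blocked) and the bool-cast version used in greedy_decode (True = blocked) really behave the same. A small standalone check (toy sizes, not my real model) gives identical attention outputs for both, so the cast itself does not seem to be the issue:

import torch
import torch.nn as nn

def square_subsequent_mask(sz):
    m = (torch.triu(torch.ones(sz, sz)) == 1).transpose(0, 1)
    return m.float().masked_fill(m == 0, float('-inf')).masked_fill(m == 1, 0.0)

torch.manual_seed(0)
attn = nn.MultiheadAttention(embed_dim=16, num_heads=2)  # dropout=0.0, so deterministic
x = torch.randn(5, 1, 16)                                # (seq, batch, dim)

float_mask = square_subsequent_mask(5)    # additive: 0.0 allowed, -inf blocked
bool_mask = float_mask.type(torch.bool)   # -inf -> True (blocked), 0.0 -> False (allowed)

out_f, _ = attn(x, x, x, attn_mask=float_mask)
out_b, _ = attn(x, x, x, attn_mask=bool_mask)
print(torch.allclose(out_f, out_b))       # prints True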

I’ve spent days trying to figure out what is wrong, but I haven’t been able to find it. I would really appreciate your help and wisdom.