Problem with conditioning a transformer decoder on float values

I implemented a transformer decoder and first trained it on text sequences alone, with satisfactory results. I then wanted to condition the model on floating-point values. To do this, I mapped each floating-point value to an n_embd-dimensional embedding with a small MLP (two linear layers with a ReLU in between) and passed that embedding to the decoder as the memory input. However, when training with this setup, the loss was unusually low from the very start, which made me suspect overfitting.
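For clarity, this is the shape flow I have in mind for the conditioning path (a stripped-down sketch of what the full model below does, using a batch of 8 scalar properties):

import torch
import torch.nn as nn

n_embd, property_hidden_dim = 64, 128
property_processor = nn.Sequential(
    nn.Linear(1, property_hidden_dim),    # property_dim = 1
    nn.ReLU(),
    nn.Linear(property_hidden_dim, n_embd),
)

props = torch.randn(8, 1)           # [batch, property_dim]
cond = property_processor(props)    # [batch, n_embd]
memory = cond.unsqueeze(0)          # [1, batch, n_embd]; in the full model this vector is repeated along the sequence dimension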

Despite my concerns, when I sampled from the model by feeding it a floating-point value and a start token, the output sequences were unexpectedly repetitive, just two letters alternating. This suggests something beyond simple overfitting, since an overfitted model would tend to reproduce training sequences verbatim rather than collapse into a repetitive pattern.
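For reference, the sampling I describe is essentially greedy decoding, one token at a time, conditioned on the float value. A simplified version of the loop looks roughly like this (start_token_id, end_token_id and max_len are placeholders for my vocabulary's special tokens and length limit):

@torch.no_grad()
def sample(model, prop_value, start_token_id, end_token_id, max_len=128, device='cuda'):
    model.eval()
    prop = torch.tensor([[prop_value]], device=device)        # [1, property_dim]
    tokens = torch.tensor([[start_token_id]], device=device)  # [1, 1], grown one token per step
    for _ in range(max_len):
        logits = model(tokens, tokens, prop)                  # [1, cur_len, vocab_size]
        next_id = logits[:, -1, :].argmax(dim=-1, keepdim=True)  # greedy pick at the last position
        tokens = torch.cat([tokens, next_id], dim=1)
        if next_id.item() == end_token_id:
            break
    return tokens.squeeze(0).tolist()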

To diagnose the issue, I sampled outputs every 100 batches during training, and those sequences looked normal, which suggests the training loop itself is not fundamentally broken.

Given these observations, what do you think could be the problem with my approach?

import torch
import torch.nn as nn
import torch.nn.functional as F
import math

class PositionalEncoding(nn.Module):
    def __init__(self, d_model, dropout=0.1, max_len=5000):
        super().__init__()
        
        self.dropout = nn.Dropout(p=dropout)
        position = torch.arange(max_len).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2) * -(math.log(10000.0) / d_model))
        pe = torch.zeros(max_len, 1, d_model)
        pe[:, 0, 0::2] = torch.sin(position * div_term)
        pe[:, 0, 1::2] = torch.cos(position * div_term)
        self.register_buffer('pe', pe)

    def forward(self, x):
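        # self.pe has shape [max_len, 1, d_model] and is sliced along dim 0 with x.size(0),
        # i.e. this module adds positional encodings assuming x is [seq_len, batch, d_model]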
        x = x + self.pe[:x.size(0)]
        return self.dropout(x)

class ConditionalTransformerModel(nn.Module):
    def __init__(self, vocab_size, n_embd, n_head, n_layers, max_length, property_dim, property_hidden_dim, dropout=0.1):
        super().__init__()

        self.embed = nn.Embedding(vocab_size, n_embd)
        self.n_embd = n_embd
        self.property_processor = nn.Sequential(
            nn.Linear(property_dim, property_hidden_dim),
            nn.ReLU(),
            nn.Linear(property_hidden_dim, n_embd)
        )

        decoder_layer = nn.TransformerDecoderLayer(d_model=n_embd, nhead=n_head, dropout=dropout)
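        # Note: nn.TransformerDecoderLayer defaults to batch_first=False, so tgt and memory
        # must be shaped [seq_len, batch, d_model]; forward() transposes accordingly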
        self.transformer_decoder = nn.TransformerDecoder(decoder_layer, num_layers=n_layers)
        self.pos_encoder = PositionalEncoding(n_embd, dropout, max_length)
        self.generator = nn.Linear(n_embd, vocab_size)
        self.max_length = max_length

    def generate_square_subsequent_mask(self, sz):
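        # Standard causal mask: 0.0 on and below the diagonal, -inf above it,
        # so position i can only attend to positions <= i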
        mask = (torch.triu(torch.ones(sz, sz)) == 1).transpose(0, 1)
        mask = mask.float().masked_fill(mask == 0, float('-inf')).masked_fill(mask == 1, float(0.0))
        return mask

    def forward(self, src, tgt, property):

        batch_size, seq_length = src.size(0), src.size(1)
        src = self.embed(src) * math.sqrt(self.n_embd)
        tgt = self.embed(tgt) * math.sqrt(self.n_embd)
        src = self.pos_encoder(src)
        tgt = self.pos_encoder(tgt)
        property_encoded = self.property_processor(property)
        property_encoded = property_encoded.unsqueeze(1).repeat(1, seq_length, 1)
        memory = property_encoded.transpose(0, 1)
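        # memory: [seq_length, batch, n_embd], i.e. the same conditioning vector repeated at every memory position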

        tgt_mask = self.generate_square_subsequent_mask(tgt.size(1)).to(tgt.device)

        output = self.transformer_decoder(tgt.transpose(0, 1), memory, tgt_mask=tgt_mask)
        output = output.transpose(0, 1)
        return self.generator(output)


# Instantiate the model (vocab and data_loader are defined elsewhere; vocab provides vocab_size and decode())
model = ConditionalTransformerModel(vocab_size=vocab.vocab_size, n_embd=64, n_head=4, n_layers=4, max_length=512, property_dim=1, property_hidden_dim=128).to('cuda')

# Training loop
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
criterion = nn.CrossEntropyLoss()

def sample_predictions(output, vocab, k=5):
    """Sample some predictions from the model output"""
    # Get the most probable next tokens (top-k)
    _, topk_indices = torch.topk(output, k, dim=-1)
    sampled_indices = topk_indices[:, :, 0]  # For simplicity, use the top-1 prediction
    sampled_tokens = [vocab.decode(indices.tolist()) for indices in sampled_indices]
    return sampled_tokens

for epoch in range(10):  # 10 epochs for now
    model.train()  # Set the model to training mode
    total_loss = 0  # To accumulate loss over the epoch
    
    for i, (src_batch, tgt_batch, properties_batch) in enumerate(data_loader):
        src_batch, tgt_batch, properties_batch = src_batch.to('cuda'), tgt_batch.to('cuda'), properties_batch.to('cuda')
        
        optimizer.zero_grad()  # Zero the gradients
        
        # Forward pass: src_batch and tgt_batch are token id tensors, properties_batch holds the float conditioning values
        output = model(src_batch, tgt_batch, properties_batch)
        
        # output is [batch_size, seq_len, vocab_size]; CrossEntropyLoss expects [batch_size, vocab_size, seq_len],
        # so transpose before computing the loss against tgt_batch
        loss = criterion(output.transpose(1, 2), tgt_batch)
        
        loss.backward()  # Backpropagate the loss
        optimizer.step()  # Update model parameters
        total_loss += loss.item()  # Accumulate the loss
    
        if i % 100 == 0:  # Sample predictions every 100 batches
            sampled_tokens = sample_predictions(output, vocab)
            print(f"Batch {i}: Sample Predictions: {sampled_tokens[:5]}")  # Print sample predictions from the first 5 sequences

    # Print average loss for the epoch
    print(f"Epoch {epoch+1}, Loss: {total_loss / len(data_loader)}")

This is the code I am using. Any comments?