Transformer loss remains the same during training

I am trying to implement a smaller version of the ImageBind model (consisting of the text, image, and video modalities only).

Source: GitHub - facebookresearch/ImageBind: ImageBind One Embedding Space to Bind Them All

Although the implementation looks solid to me, the loss during training is essentially constant: it stays at ln(batch size). With a batch size of 3, for example, the loss stays at about 1.0986.
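
For reference, 1.0986 is exactly the value cross-entropy gives when the softmax over the batch is uniform, i.e. when the logits carry no information about which pair is the positive. A minimal standalone check (not taken from my repository; B, logits and labels are just illustrative names):

import math
import torch
from torch.nn.functional import cross_entropy

B = 3
logits = torch.zeros(B, B)   # identical similarity for every pair -> uniform softmax
labels = torch.arange(B)     # diagonal pairs are the positives
print(cross_entropy(logits, labels).item())  # ~1.0986
print(math.log(B))                           # ln(3) = 1.0986...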

Here is a snippet:

import torch
from torch import nn
from torch.nn.functional import normalize, cross_entropy
from tqdm import tqdm

DEVICE = "cpu"

def contrastive_loss(x, y, temperature=0.07):
    """
    InfoNCE-style contrastive loss between x and y.
    Diagonal pairs (x[i], y[i]) are treated as the positives.
    """
    logits = (x @ y.T) / temperature
    labels = torch.arange(x.size(0), device=x.device)
    loss_i2t = cross_entropy(logits, labels)    # image -> text direction
    loss_t2i = cross_entropy(logits.T, labels)  # text -> image direction
    return (loss_i2t + loss_t2i) / 2

# model, image_dl and EPOCHS are defined elsewhere in the repository
optimizer = torch.optim.AdamW(model.parameters(), lr=5e-3, weight_decay=0.01)
loss_fn = nn.MSELoss()  # defined but unused; contrastive_loss above is what is actually applied

# Training loop
for epoch in range(EPOCHS):
    model.train()
    total_loss = 0.0
    step = 0

    pbar = tqdm(image_dl, desc=f"Epoch {epoch+1}/{EPOCHS}")

    for image, text in pbar:
        # Move data to device
        image = image.to(DEVICE)
        text = text.to(DEVICE)

        # Forward pass
        inputs = {"vision": image, "text": text}
        embeddings = model(inputs)

        # Normalize embeddings
        img_emb = normalize(embeddings["vision"], dim=-1)
        txt_emb = normalize(embeddings["text"], dim=-1)

        # Compute contrastive loss
        loss = contrastive_loss(img_emb, txt_emb)
        print(loss)

        # Backpropagation
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Logging
        total_loss += loss.item()
        step += 1
        pbar.set_postfix(loss=total_loss / step)

    print(f"Epoch {epoch+1}: Avg Loss = {total_loss / step:.4f}")

Please take a look at my implementation: GitHub - Mafaz03/ImageBind
Extra points:

  1. requires_grad is True for every parameter.
  2. param.grad.abs().mean() is almost 0 for every parameter (checked with a loop like the sketch below).
  3. I have also tried other initializations, such as Xavier, and that didn't fix it.
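
A minimal sketch of the kind of gradient check referred to in point 2 (run immediately after loss.backward(); model is the same model as in the training loop above):

for name, param in model.named_parameters():
    if param.grad is None:
        print(f"{name}: requires_grad={param.requires_grad}, grad=None")
    else:
        print(f"{name}: requires_grad={param.requires_grad}, "
              f"grad_mean={param.grad.abs().mean().item():.3e}")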

Thank you.