Trainable parameters become NaN after optimizer.step() at the end of the first iteration

Hi, I’m trying to train a custom projection module to align image and text embeddings in a common embedding space. There is an MLP module that learns positional embeddings from the coordinates of the image patches. The dataset consists of very large images, so each image is split into patches that are processed together with their corresponding coordinates to produce image-level embeddings.

At the end of the first iteration, immediately after optimizer.step(), the model parameters become NaN.
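
This is roughly how I confirm it (a minimal sketch around the backward/step calls; vlm, loss, and optimizer refer to the objects in the full listing below):

loss.backward()

# Check gradients before the update
for name, param in vlm.named_parameters():
    if param.grad is not None and not torch.isfinite(param.grad).all():
        print(f"non-finite gradient in {name}")

optimizer.step()

# Check parameters after the update -- this is where the NaNs show up
for name, param in vlm.named_parameters():
    if not torch.isfinite(param).all():
        print(f"non-finite parameter after step: {name}")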

I have attached the code snippet below. Any help is highly appreciated, thanks!

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torch.nn.utils as nn_utils
from tqdm import tqdm


class CustomVLM(nn.Module):
    def __init__(self, coord_mlp, projection_module, device="cuda"):
        super(CustomVLM, self).__init__()
        self.device = device
        
        self.coord_mlp = coord_mlp
        
        self.projection_module = projection_module
        checkpoint = torch.load("mm_projector.bin")
        checkpoint = {k.replace("mm_projector.", ""): v for k, v in checkpoint.items()}
        self.projection_module.load_state_dict(checkpoint)
        self.projection_module = self.projection_module.to(self.device).half()  # NOTE: this fp16 cast turned out to be the source of the NaNs (see resolution at the bottom)
        
    def forward(self, last_minus1_k_features, last_k_features, patch_coords_pt, text_embeddings):       
        if torch.isnan(last_minus1_k_features).any() or torch.isnan(last_k_features).any() or torch.isnan(patch_coords_pt).any() or torch.isnan(text_embeddings).any():
            raise ValueError("Input tensors contain NaN")
        
        patch_coords_pt = patch_coords_pt.half().to(self.device)
        if torch.isnan(patch_coords_pt).any():
            raise ValueError("patch_coords_pt contains NaN")

        text_embeddings = text_embeddings.squeeze(0)
        
        # Generate coords embeddings
        coord_embed = self.coord_mlp(patch_coords_pt).unsqueeze(1)
        if torch.isnan(coord_embed).any():
            raise ValueError("coord_embed contains NaN")
        
        # Add patch and coords embeddings --> stack to get the image-level embeddings
        image_level_embeddings_kMinus1 = last_minus1_k_features + coord_embed[:-1, :, :]        
        image_level_embeddings_kth = (last_k_features + coord_embed[-1, :, :])
        image_level_embeddings_kMinus1_concat = torch.cat([t for t in image_level_embeddings_kMinus1], dim=0)
        
        combined_image_embeddings = torch.cat([image_level_embeddings_kMinus1_concat, image_level_embeddings_kth], dim=0).unsqueeze(0) #torch.Size([1, 729, 1152])
        if torch.isnan(combined_image_embeddings).any():
            raise ValueError("combined_image_embeddings contains NaN")
                
        # Pass the resulting embeddings through the projection module
        # projected_embeddings = self.projection_module(combined_image_embeddings.half()).squeeze(0)
        projected_embeddings = self.projection_module(combined_image_embeddings).squeeze(0)       
        
        
        # Ensure no NaNs in projected embeddings
        if torch.isnan(projected_embeddings).any():
            raise ValueError("projected_embeddings contains NaN")
        if torch.isinf(projected_embeddings).any():
            raise ValueError("projected_embeddings contains Inf")
        
        loss = contrastive_loss(projected_embeddings, text_embeddings)

        return loss
class MLPBlock(nn.Module):
    def __init__(self, input_dim: int, embedding_dim: int) -> None:
        super().__init__()
        self.lin1 = nn.Linear(input_dim, embedding_dim * 2)  # input_dim = 2 --> (row, col); hidden dim = 1152 * 2 = 2304
        self.lin2 = nn.Linear(embedding_dim * 2, embedding_dim)  # embedding_dim = 1152 --> output embedding vector
        self.act = nn.GELU()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.lin2(self.act(self.lin1(x)))
class ProjectionModule(nn.Module):
    def __init__(self, mm_hidden_size, hidden_size):
        super(ProjectionModule, self).__init__()

        # Directly set up the sequential model
        self.model = nn.Sequential(
            nn.Linear(mm_hidden_size, hidden_size),
            nn.GELU(),
            nn.Linear(hidden_size, hidden_size),
        )

    def forward(self, x):
        return self.model(x)
def cross_entropy(preds, targets, reduction='none'):
    log_softmax = nn.LogSoftmax(dim=-1)
    loss = (-targets * log_softmax(preds)).sum(1)
    if reduction == "none":
        return loss
    elif reduction == "mean":
        return loss.mean()
        
def contrastive_loss(image_embeddings, text_embeddings, temperature=1.0):
    image_embeddings = nn.functional.normalize(image_embeddings, dim=-1)
    text_embeddings = nn.functional.normalize(text_embeddings, dim=-1)

    print("image_embeddings:", image_embeddings)
    print("text_embeddings:", text_embeddings)
    
    logits = (text_embeddings @ image_embeddings.T) /  temperature
    images_similarity = image_embeddings @ image_embeddings.T
    texts_similarity = text_embeddings @ text_embeddings.T
    targets = F.softmax((images_similarity + texts_similarity) / 2 * temperature, dim=-1)
    
    texts_loss = cross_entropy(logits, targets, reduction='none')
    images_loss = cross_entropy(logits.T, targets.T, reduction='none')
    loss =  (images_loss + texts_loss) / 2.0 # shape: (batch_size)
    return loss.mean()
coord_mlp = MLPBlock(2, 1152)
coord_mlp = coord_mlp.to('cuda').half()  # fp16 cast (see resolution at the bottom)

projection_module = ProjectionModule(1152, 4096)

vlm = CustomVLM(coord_mlp, projection_module)

for param in vlm.parameters():
    param.requires_grad = True

vlm.train()

optimizer = optim.AdamW(list(vlm.parameters()), lr=1e-5)
for epoch in range(num_epochs):
    epoch_loss = 0.0  # Initialize epoch_loss to accumulate loss for the entire epoch
    
    # Open text file for current epoch
    with open(f"training_logs/epoch_{epoch+1}_loss.txt", "a") as loss_file:
    
        for batchID, batch in enumerate(tqdm(data_loader)):
            last_minus1_k_features, last_k_features, patch_coords_pt, text_embeddings = batch
            last_minus1_k_features = last_minus1_k_features.squeeze(0)
            last_k_features = last_k_features.squeeze(0)
            patch_coords_pt = patch_coords_pt.squeeze(0)
            
            loss = vlm(last_minus1_k_features, last_k_features, patch_coords_pt, text_embeddings)            
                        
            # Backward pass and update weights
            optimizer.zero_grad(set_to_none=True)
            loss.backward()

            # Clip gradients using L2 norm (default)
            max_norm = 1.0  # Adjust this value as needed
            nn_utils.clip_grad_norm_(vlm.parameters(), max_norm, norm_type=2.0)
            
            optimizer.step()
            # scheduler.step()

            epoch_loss += loss.item()  # Accumulate loss for the current batch
            
            # Write loss for current batch to file
            loss_file.write(f"Epoch {epoch+1}, Batch {batchID+1}, Loss: {loss.item():.4f}\n")
      
        epoch_loss /= len(data_loader)  # Calculate average loss for the epoch

        # Print the epoch and average loss
        print(f"\nEpoch: {epoch + 1}/{num_epochs}, Average Loss: {epoch_loss:.4f}")

This issue is solved: the network layers (coord_mlp and the projection module) were of dtype float16 because of the .half() casts above. Switching them back to float32 fixed the NaNs.
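
For anyone who runs into the same thing: with the parameters themselves stored in float16, the AdamW update is numerically fragile (for example, eps=1e-8 underflows to zero in fp16), so parameters can turn NaN even when the loss and gradients look fine. A minimal sketch of the fix is below; it assumes the .half() calls inside CustomVLM.__init__() and forward() are removed as well, and shows the usual mixed-precision route (autocast + GradScaler) if half precision is still wanted for memory, rather than casting the modules themselves:

# Keep the trainable modules in float32, i.e. drop the .half() casts
# (sketch: assumes the .half() calls inside CustomVLM are removed too).
coord_mlp = MLPBlock(2, 1152).to("cuda")            # stays float32
projection_module = ProjectionModule(1152, 4096)    # stays float32
vlm = CustomVLM(coord_mlp, projection_module)

optimizer = optim.AdamW(vlm.parameters(), lr=1e-5)

# Optional: run the forward pass in fp16 while keeping fp32 master weights.
scaler = torch.cuda.amp.GradScaler()

for batch in data_loader:
    last_minus1_k_features, last_k_features, patch_coords_pt, text_embeddings = batch

    optimizer.zero_grad(set_to_none=True)
    with torch.cuda.amp.autocast():
        loss = vlm(last_minus1_k_features.squeeze(0),
                   last_k_features.squeeze(0),
                   patch_coords_pt.squeeze(0),
                   text_embeddings)

    scaler.scale(loss).backward()
    scaler.unscale_(optimizer)                        # unscale before clipping
    nn_utils.clip_grad_norm_(vlm.parameters(), 1.0)
    scaler.step(optimizer)
    scaler.update()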