Getting NaN training and validation loss when training a BERT model in PyTorch

I am training a pretrained BERT model for an NER task. When I configure the device to cuda, the gradients become NaN during backpropagation and the loss comes out as NaN. This does not happen when the device is set to cpu or mps (I am using a Mac M1 chip). I am not sure what in my code could be causing this. Can anyone offer advice to point me in the right direction?

This is my training loop. The validation loop is very similar, with self.model and self.classifier set to .eval(); a rough sketch of that variant follows the loop below.

def _train_single_epoch(
        self, train_dataloader: DataLoader, optimizer=None
    ) -> float:

        self.model.to(self.device)
        self.classifier.to(self.device)
        self.model.train()
        self.classifier.train()
        train_loss = 0.0
        true_labels, pred_labels = [], []
        for train_id, train_mask, train_label, report_id in tqdm(train_dataloader):
            # Forward pass
            input_id = train_id.squeeze(1).to(self.device)
            mask = train_mask.squeeze(1).to(self.device)
            train_label = train_label.to(self.device)
            report_id = report_id.to(self.device)

            # Zero gradients
            if optimizer is not None:
                optimizer.zero_grad()
            loss, logits, _ = self.forward(input_id, mask, train_label, report_id)

            # Update train loss
            train_loss += loss.item()
            preds = logits.argmax(dim=-1)
            true_labels.extend(train_label.cpu().numpy().tolist())
            pred_labels.extend(preds.cpu().numpy().tolist())

            # Backpropagate
            loss.backward()

            # Update model parameters with respect to the gradients
            if optimizer is not None:
                optimizer.step()
        train_loss /= len(train_dataloader)
        return train_loss
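
For reference, the validation variant would look roughly like this (a minimal sketch based on the description above; _validate_single_epoch is a placeholder name, and it assumes the same forward() signature):

def _validate_single_epoch(self, val_dataloader: DataLoader) -> float:
    self.model.to(self.device)
    self.classifier.to(self.device)
    self.model.eval()
    self.classifier.eval()
    val_loss = 0.0
    # No gradients are needed during evaluation
    with torch.no_grad():
        for val_id, val_mask, val_label, report_id in tqdm(val_dataloader):
            input_id = val_id.squeeze(1).to(self.device)
            mask = val_mask.squeeze(1).to(self.device)
            val_label = val_label.to(self.device)
            report_id = report_id.to(self.device)
            loss, logits, _ = self.forward(input_id, mask, val_label, report_id)
            val_loss += loss.item()
    return val_loss / len(val_dataloader)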

This is my forward method, used to compute the loss and the logit predictions:

def forward(
        self,
        input_id: torch.Tensor,
        mask: torch.Tensor,
        label_tag: torch.Tensor,
        report_ids: torch.Tensor,
        is_inference: bool = False,
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:

        # Unpack the batch size, number of chunks, and chunk size
        if len(input_id.size()) == 2:
            batch_size, chunk_size = input_id.size()
            num_chunks = 1
        else:
            batch_size, num_chunks, chunk_size = input_id.size()

        # Reshape into (batch_size * num_chunks, chunk_size)
        self.logger.debug("Reshaping into (batch_size * num_chunks, chunk_size)...")
        input_id = input_id.view(-1, chunk_size)
        mask = mask.view(-1, chunk_size)
        label_tag = label_tag.view(-1, chunk_size)
        report_ids = report_ids.view(-1, chunk_size)

        self.logger.debug("Getting top-layer of pre-trained model...")
        output = self.model(input_ids=input_id, attention_mask=mask)

        # Pass the top-layer hidden states of the pre-trained model through the classifier
        logits = self.classifier(output[0])
        loss = None

        if not is_inference:
            self.logger.debug("Calculating loss...")
            loss_fn = self.criterion
            # Create a boolean tensor indicating which elements in the flattened mask
            # equal 1
            active_loss = mask.view(-1) == 1

            # Reshape logits for all tokens into a 2D tensor with size (total_tokens,
            # num_labels)
            active_logits = logits.view(-1, self.num_labels)

            # Replace ignored labels (where active_loss is False) with the -100 integer
            # label
            active_labels = torch.where(
                active_loss,
                label_tag.view(-1),
                torch.tensor(self.tags["-100"]).type_as(label_tag),
            )
            self.logger.debug(f"Unique labels: {torch.unique(active_labels)}")
            self.logger.debug(
                f"Number of unique labels in tensor: {torch.unique(active_labels).numel()}"
            )
            self.logger.debug(
                f"Number of unique labels in num_labels: {self.num_labels}"
            )

            # Ensure active_labels are within the correct range
            if torch.unique(active_labels).numel() > self.num_labels:
                self.logger.error(
                    f"Label {active_labels.max()} is out "
                    f"of bounds for {self.num_labels} classes."
                )
                raise ValueError(
                    f"Label {active_labels.max()} is out "
                    f"of bounds for {self.num_labels} classes."
                )
            self.logger.debug(
                f"Size of active_logits:{active_logits.size()}, "
                f"size of active_labels: {active_labels.size()}"
            )
            loss = loss_fn(active_logits, active_labels)

        self.logger.debug("Reshaping logits...")
        # Reshape logits back into (batch_size, num_chunks, chunk_size, num_labels)
        chunked_logits = logits.view(
            batch_size, num_chunks, chunk_size, self.num_labels
        )
        output_report_ids = report_ids.view(batch_size, num_chunks, chunk_size)

        return loss, chunked_logits, output_report_ids
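
For comparison, the masking step above mirrors the pattern used in older Hugging Face token-classification heads, where nn.CrossEntropyLoss simply ignores the literal label value -100 by default. A standalone sketch of that pattern with dummy tensors (shapes and label count are illustrative):

import torch
import torch.nn as nn

batch_size, chunk_size, num_labels = 2, 8, 5
logits = torch.randn(batch_size, chunk_size, num_labels)         # float logits from the classifier
labels = torch.randint(0, num_labels, (batch_size, chunk_size))  # integer tag ids, dtype torch.long
mask = torch.ones(batch_size, chunk_size, dtype=torch.long)
mask[:, -2:] = 0                                                 # pretend the last two tokens are padding

loss_fct = nn.CrossEntropyLoss()                                 # ignore_index defaults to -100

active_loss = mask.view(-1) == 1
active_logits = logits.view(-1, num_labels)
active_labels = torch.where(
    active_loss,
    labels.view(-1),
    torch.tensor(loss_fct.ignore_index).type_as(labels),         # the literal value -100, not a class id
)
loss = loss_fct(active_logits, active_labels)
print(loss)                                                      # finite scalar as long as some labels are active

The detail that matters in this pattern is that the filler value is the literal -100 that CrossEntropyLoss skips, rather than an index pointing at a real class.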

This is a snippet of what happens when I print out the gradients of each layer of the BERT transformer during backpropagation:

Gradient for encoder.layer.0.attention.self.query.weight contains NaNs:
tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        ...,
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0')
Gradient for encoder.layer.0.attention.self.query.bias contains NaNs:
tensor([nan, nan, nan,  ..., nan, nan, nan], device='cuda:0')
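
That log comes from printing the gradients of each parameter after loss.backward(), roughly along these lines (a sketch of the idea rather than my exact code); torch.autograd.set_detect_anomaly(True) is another way to localize where the NaNs first appear:

# Rough sketch of the per-parameter gradient check, run right after loss.backward()
for name, param in self.model.named_parameters():
    if param.grad is not None and torch.isnan(param.grad).any():
        self.logger.error(f"Gradient for {name} contains NaNs: {param.grad}")

# Enabling anomaly detection before the forward pass makes the backward error
# point at the operation that first produced a NaN/Inf value
torch.autograd.set_detect_anomaly(True)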

My tensors are in torch.long (64-bit integer) format, as I need the predictions (logits) returned as whole integer numbers, which I initially thought could have caused this problem. But I have tried other torch tensor formats, such as 32-bit integers, and it still does not resolve the issue. The optimizer and its parameters used are as follows:

optimizer:
    type: 'AdamW'
    params:
      lr: 1e-5
      weight_decay: 0.005
      betas: [0.9, 0.999]
      eps: 1e-8
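
For reference, that config corresponds to constructing the optimizer roughly like this (a sketch that assumes the parameters of both self.model and self.classifier are passed in):

import itertools
import torch

optimizer = torch.optim.AdamW(
    itertools.chain(self.model.parameters(), self.classifier.parameters()),
    lr=1e-5,
    weight_decay=0.005,
    betas=(0.9, 0.999),
    eps=1e-8,
)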

Torch tensor types URL: torch.Tensor — PyTorch 2.4 documentation
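
On the dtype point above: only the token ids and the label tags need to be torch.long; the logits themselves have to stay floating point for the loss and gradient computation, and whole-number predictions come out of argmax anyway. A small illustration (all sizes are made up):

import torch

input_ids = torch.randint(0, 30522, (1, 128), dtype=torch.long)  # token ids: torch.long
labels = torch.randint(0, 9, (1, 128), dtype=torch.long)         # NER tag ids: torch.long

logits = torch.randn(1, 128, 9)   # classifier output stays float32
preds = logits.argmax(dim=-1)     # whole-number predictions, dtype torch.int64
print(preds.dtype)                # torch.int64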

Anyone able to help out on this?

Lzwk, I ran into this same issue when training a ModernBERT model: it worked on cpu and mps but failed on cuda. I'm not sure if you ever found a solution to your problem, but I was able to work around it by setting model = model.double() to force float64. It's not ideal though, so I'd love to know if you found another solution.
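
Concretely, that workaround amounts to promoting the floating-point parameters to float64 (a rough sketch; integer tensors such as the token ids stay torch.long):

model = model.double()            # promote all floating-point parameters/buffers to float64
classifier = classifier.double()  # same for a separate classification head, if there is one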