I am training a pretrained BERT model for a NER task. When I set the device to cuda, the gradients turn into NaNs during backpropagation. This does not happen when the device is set to cpu or mps (I am using a Mac M1 chip). I am not sure what in my code could be causing this. Can anyone offer advice to point me in the right direction?
This is my training loop. The validation loop is broadly similar, with self.model and self.classifier set to .eval():
def _train_single_epoch(
    self, train_dataloader: DataLoader, optimizer=None
) -> float:
    self.model.to(self.device)
    self.classifier.to(self.device)
    self.model.train()
    self.classifier.train()
    train_loss = 0.0
    true_labels, pred_labels = [], []
    for train_id, train_mask, train_label, report_id in tqdm(train_dataloader):
        # Forward pass
        input_id = train_id.squeeze(1).to(self.device)
        mask = train_mask.squeeze(1).to(self.device)
        train_label = train_label.to(self.device)
        report_id = report_id.to(self.device)
        # Zero gradients
        if optimizer is not None:
            optimizer.zero_grad()
        loss, logits, _ = self.forward(input_id, mask, train_label, report_id)
        # Update train loss
        train_loss += loss.item()
        preds = logits.argmax(dim=-1)
        true_labels.extend(train_label.cpu().numpy().tolist())
        pred_labels.extend(preds.cpu().numpy().tolist())
        # Backpropagate
        loss.backward()
        # Update model parameters with respect to the gradient
        if optimizer is not None:
            optimizer.step()
    train_loss /= len(train_dataloader)
    return train_loss
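To narrow down where the NaNs first appear, PyTorch's built-in anomaly detection can be wrapped around a single training step; it makes backward raise an error at the first operation that produced a NaN, at the cost of slower execution. A minimal sketch (forward_fn is a hypothetical stand-in for the self.forward call above):

import torch

# Sketch: one training step under anomaly detection. Enable only while
# debugging, as it slows training considerably.
with torch.autograd.detect_anomaly():
    loss, logits, _ = forward_fn(input_id, mask, train_label, report_id)
    loss.backward()  # raises at the first NaN-producing op in the graph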
This is my forward method, which computes the loss and the logit predictions:
def forward(
    self,
    input_id: torch.Tensor,
    mask: torch.Tensor,
    label_tag: torch.Tensor,
    report_ids: torch.Tensor,
    is_inference: bool = False,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    # Unpack the batch size, number of chunks, and chunk size
    if len(input_id.size()) == 2:
        batch_size, chunk_size = input_id.size()
        num_chunks = 1
    else:
        batch_size, num_chunks, chunk_size = input_id.size()
    # Reshape into (batch_size * num_chunks, chunk_size)
    self.logger.debug("Reshaping into (batch_size * num_chunks, chunk_size)...")
    input_id = input_id.view(-1, chunk_size)
    mask = mask.view(-1, chunk_size)
    label_tag = label_tag.view(-1, chunk_size)
    report_ids = report_ids.view(-1, chunk_size)
    # Run the pre-trained model and classify its top-layer hidden states
    self.logger.debug("Getting top layer of pre-trained model...")
    output = self.model(input_ids=input_id, attention_mask=mask)
    logits = self.classifier(output[0])
    loss = None
    if not is_inference:
        self.logger.debug("Calculating loss...")
        loss_fn = self.criterion
        # Boolean tensor indicating which elements of the flattened mask equal 1
        active_loss = mask.view(-1) == 1
        # Reshape logits for all tokens into a 2D tensor of size
        # (total_tokens, num_labels)
        active_logits = logits.view(-1, self.num_labels)
        # Replace ignored labels (where active_loss is False) with the -100
        # integer label
        active_labels = torch.where(
            active_loss,
            label_tag.view(-1),
            torch.tensor(self.tags["-100"]).type_as(label_tag),
        )
        self.logger.debug(f"Unique labels: {torch.unique(active_labels)}")
        self.logger.debug(
            f"Number of unique labels in tensor: {torch.unique(active_labels).numel()}"
        )
        self.logger.debug(
            f"Number of unique labels in num_labels: {self.num_labels}"
        )
        # Ensure active_labels are within the correct range
        if torch.unique(active_labels).numel() > self.num_labels:
            self.logger.error(
                f"Label {active_labels.max()} is out "
                f"of bounds for {self.num_labels} classes."
            )
            raise ValueError(
                f"Label {active_labels.max()} is out "
                f"of bounds for {self.num_labels} classes."
            )
        self.logger.debug(
            f"Size of active_logits: {active_logits.size()}, "
            f"size of active_labels: {active_labels.size()}"
        )
        loss = loss_fn(active_logits, active_labels)
    # Reshape logits back into (batch_size, num_chunks, chunk_size, num_labels)
    self.logger.debug("Reshaping logits...")
    chunked_logits = logits.view(
        batch_size, num_chunks, chunk_size, self.num_labels
    )
    output_report_ids = report_ids.view(batch_size, num_chunks, chunk_size)
    return loss, chunked_logits, output_report_ids
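For context, the masking above follows the standard token-classification pattern: padded positions are replaced with an ignore label before computing cross-entropy. A minimal self-contained version, assuming the criterion is nn.CrossEntropyLoss with ignore_index=-100 (my actual self.criterion and self.tags mapping may differ):

import torch
import torch.nn as nn

num_labels = 5
logits = torch.randn(2, 8, num_labels)         # (batch, seq_len, num_labels)
labels = torch.randint(0, num_labels, (2, 8))  # (batch, seq_len)
mask = torch.ones(2, 8, dtype=torch.long)
mask[:, 6:] = 0                                # pretend the last two tokens are padding

active_loss = mask.view(-1) == 1
active_logits = logits.view(-1, num_labels)
# Positions outside the attention mask get the ignore label -100
active_labels = torch.where(
    active_loss, labels.view(-1), torch.full_like(labels.view(-1), -100)
)
loss = nn.CrossEntropyLoss(ignore_index=-100)(active_logits, active_labels)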
This is a snippet of what I see when I print the gradients of each layer of the BERT transformer during backpropagation:
Gradient for encoder.layer.0.attention.self.query.weight contains NaNs:
tensor([[nan, nan, nan, ..., nan, nan, nan],
        [nan, nan, nan, ..., nan, nan, nan],
        [nan, nan, nan, ..., nan, nan, nan],
        ...,
        [nan, nan, nan, ..., nan, nan, nan],
        [nan, nan, nan, ..., nan, nan, nan],
        [nan, nan, nan, ..., nan, nan, nan]], device='cuda:0')
Gradient for encoder.layer.0.attention.self.query.bias contains NaNs:
tensor([nan, nan, nan, ..., nan, nan, nan], device='cuda:0')
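The check that produces this output is roughly the following loop over named parameters after loss.backward() (a sketch of what I run; the parameter names come from Hugging Face's BertModel):

import torch

for name, param in self.model.named_parameters():
    # Report any parameter whose gradient contains at least one NaN
    if param.grad is not None and torch.isnan(param.grad).any():
        print(f"Gradient for {name} contains NaNs:\n{param.grad}")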
My tensors are in torch.long (64-bit integer) format, as I need the predictions derived from the logits returned as whole integer numbers, which I initially thought could have caused this problem. But I have tried other torch tensor formats, such as 32-bit integers, and it still does not resolve it. The optimizer and its parameters are as follows:
optimizer:
  type: 'AdamW'
  params:
    lr: 1e-5
    weight_decay: 0.005
    betas: [0.9, 0.999]
    eps: 1e-8
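For completeness, this config is turned into the optimizer roughly like so (a sketch; I am assuming the parameters of both self.model and self.classifier are passed in together):

import itertools
import torch

optimizer = torch.optim.AdamW(
    itertools.chain(self.model.parameters(), self.classifier.parameters()),
    lr=1e-5,
    weight_decay=0.005,
    betas=(0.9, 0.999),
    eps=1e-8,
)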
Torch tensor types reference: torch.Tensor — PyTorch 2.4 documentation (https://pytorch.org/docs/stable/tensors.html)