MONAI tutorial debugging

Hello,
I am interested in using PyTorch with MONAI to segment multiple organs in fetal mouse scans. I’m a beginner and started by following the tutorial “3D Multi-organ Segmentation with UNETR”, which works on my system with the provided sample data.

Now I’d like to adapt the example to my own data, which has 50 labels. I am getting a runtime error when the loss function is computed. I believe it is related to the dimensions of the labels, although the shapes of the tensors passed to loss_function look as expected. Any advice on debugging this would be much appreciated.

---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
/tmp/ipykernel_35590/2004812791.py in <module>
      1 while global_step < max_iterations:
----> 2     global_step, dice_val_best, global_step_best = train(
      3         global_step, train_loader, dice_val_best, global_step_best
      4     )
      5 model.load_state_dict(torch.load(os.path.join(root_dir, "best_metric_model.pth")))

/tmp/ipykernel_35590/1903283385.py in train(global_step, train_loader, dice_val_best, global_step_best)
     38         print('logit_map',logit_map.shape)
     39         print('y',y.shape)
---> 40         loss = loss_function(logit_map, y)
     41         loss.backward()
     42         epoch_loss += loss.item()

~/.local/lib/python3.8/site-packages/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs)
   1100         if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
   1101                 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1102             return forward_call(*input, **kwargs)
   1103         # Do not call functions when jit is used
   1104         full_backward_hooks, non_full_backward_hooks = [], []

~/.local/lib/python3.8/site-packages/monai/losses/dice.py in forward(self, input, target)
    719             raise ValueError("the number of dimensions for input and target should be the same.")
    720 
--> 721         dice_loss = self.dice(input, target)
    722         ce_loss = self.ce(input, target)
    723         total_loss: torch.Tensor = self.lambda_dice * dice_loss + self.lambda_ce * ce_loss

~/.local/lib/python3.8/site-packages/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs)
   1100         if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
   1101                 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1102             return forward_call(*input, **kwargs)
   1103         # Do not call functions when jit is used
   1104         full_backward_hooks, non_full_backward_hooks = [], []

~/.local/lib/python3.8/site-packages/monai/losses/dice.py in forward(self, input, target)
    144                 warnings.warn("single channel prediction, `to_onehot_y=True` ignored.")
    145             else:
--> 146                 target = one_hot(target, num_classes=n_pred_ch)
    147 
    148         if not self.include_background:

~/.local/lib/python3.8/site-packages/monai/networks/utils.py in one_hot(labels, num_classes, dtype, dim)
     89 
     90     o = torch.zeros(size=sh, dtype=dtype, device=labels.device)
---> 91     labels = o.scatter_(dim=dim, index=labels.long(), value=1)
     92 
     93     return labels

RuntimeError: CUDA error: device-side assert triggered

Based on the stack trace, it seems you are using a Dice-based loss (the frames in monai/losses/dice.py that combine dice_loss and ce_loss point to DiceCELoss), which fails in the one_hot call on the targets, or rather in its scatter_ operation. Could you post the shapes of logit_map and y in this line of code:

loss = loss_function(logit_map, y)

(I see that you’ve already added the print statements, so you are on the right path to debug it 😉)
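Besides the shapes, the target values are worth checking, since scatter_ uses each value in y as an index into the channel dimension of logit_map. A minimal sketch of such a check, to run just before the loss call in train(), assuming y holds integer class labels (possibly stored as floats); check_target_range is a hypothetical helper, not a MONAI function:

```python
import torch


def check_target_range(y: torch.Tensor, num_classes: int) -> torch.Tensor:
    """Return the distinct values in y that are not valid channel indices.

    Hypothetical helper: a valid value lies in [0, num_classes - 1].
    The check runs on a CPU copy so it works for CUDA tensors too,
    as long as it is called before the failing loss computation.
    """
    values = torch.unique(y.detach().cpu()).long()  # sorted distinct labels
    return values[(values < 0) | (values >= num_classes)]


# Example use inside train(), right before loss_function(logit_map, y):
# bad = check_target_range(y, logit_map.shape[1])
# if bad.numel() > 0:
#     raise ValueError(f"target values {bad.tolist()} are outside "
#                      f"[0, {logit_map.shape[1] - 1}]")
```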

Thanks. For my data with 50 labels, the shapes are:

logit_map: torch.Size([4, 51, 96, 96, 96])
y: torch.Size([4, 1, 96, 96, 96])

I compared this to the original example using the BTCV challenge data with 13 labels:
logit_map: torch.Size([4, 14, 96, 96, 96])
y: torch.Size([4, 1, 96, 96, 96])
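In both cases the channel count is the number of classes including background (13 organs + background = 14, 50 labels + background = 51), and because the loss one-hot encodes y, every value in the label volumes must also be a valid channel index in 0..C-1. A sketch, assuming the label volumes can be iterated as tensors (for example after loading them with the tutorial's transforms; scan_label_values is a hypothetical name), that reports which values actually occur, how many channels they imply, and whether they already form the consecutive range 0..C-1:

```python
from typing import Iterable, List, Tuple

import torch


def scan_label_values(labels: Iterable[torch.Tensor]) -> Tuple[List[int], int, bool]:
    """Collect the distinct label values over a dataset (hypothetical helper).

    Returns the sorted values, the number of output channels they imply
    (one per distinct value, background included) and whether the values
    are already the consecutive range 0..C-1 that one-hot encoding needs.
    """
    seen = set()
    for lab in labels:
        seen.update(torch.unique(lab).long().tolist())
    values = sorted(seen)
    num_channels = len(values)
    consecutive = values == list(range(num_channels))
    return values, num_channels, consecutive
```

If consecutive comes back False, the labels need remapping before training, and num_channels is the out_channels the model would need after that remapping.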

Your shapes seem to work for me as seen here:

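import torch
from monai.losses import DiceLoss

output = torch.randn(4, 51, 96, 96, 96, requires_grad=True)
# random targets drawn from [0, 50], i.e. valid indices for the 51 output channels
target = torch.randint(0, 51, (4, 1, 96, 96, 96))

criterion = DiceLoss(softmax=True, to_onehot_y=True)
loss = criterion(output, target)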

Could you compare my code snippet with yours and check where the difference might be coming from?
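One way to make that comparison concrete is to inject an out-of-range label into an otherwise valid random target and run the one-hot step on the CPU. There PyTorch checks the index synchronously and names the offending value, instead of the bare "device-side assert triggered" seen on CUDA. The sketch below mirrors the zeros + scatter_ pattern visible in the traceback (monai/networks/utils.py) rather than calling MONAI itself, so one_hot_like_monai is a hypothetical stand-in and 102 is just an example of a value beyond 51 channels:

```python
import torch


def one_hot_like_monai(labels: torch.Tensor, num_classes: int) -> torch.Tensor:
    """Channel-dim one-hot via zeros + scatter_, as in the traceback.

    labels: (B, 1, *spatial) integer class indices.
    """
    shape = list(labels.shape)
    shape[1] = num_classes
    out = torch.zeros(shape, dtype=torch.float32, device=labels.device)
    # scatter_ writes 1 at out[b, labels[b, 0, ...], ...]; any label value
    # >= num_classes is an out-of-bounds index along dim 1.
    return out.scatter_(dim=1, index=labels.long(), value=1)


target = torch.randint(0, 51, (1, 1, 8, 8, 8))
ok = one_hot_like_monai(target, 51)  # works: every value is in [0, 50]

target[0, 0, 0, 0, 0] = 102  # a single label value outside the 51 channels
try:
    one_hot_like_monai(target, 51)
except RuntimeError as err:
    print("CPU error:", err)  # readable out-of-bounds message on CPU
```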

Thanks, this is very helpful. Besides the values 0-49, my label array contains two non-consecutive values, 102 and 103. That is probably what causes this indexing error.

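Yes, this would create an indexing error: one_hot uses each target value as a channel index, so 102 and 103 are out of bounds for the 51 channels of logit_map.
Also note that you are apparently dealing with 52 labels (0-49 plus 102 and 103), which doesn’t fit the model output shape of 51 channels.
In any case, remap the target indices to the expected range [0, nb_classes-1].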
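A minimal sketch of such a remapping, assuming the set of label values is known in advance (here 0-49 plus 102 and 103, as reported above): build a lookup table from original value to consecutive index and index into it. remap_labels is a hypothetical helper; with 52 classes after remapping, the model's out_channels and any other place the tutorial hard-codes the class count (14 for BTCV) would need to become 52. MONAI also ships a MapLabelValued transform intended for this kind of mapping, which may fit more naturally into the tutorial's transform pipeline; check its arguments in your installed version.

```python
import torch


def remap_labels(label: torch.Tensor, orig_values) -> torch.Tensor:
    """Map each value in orig_values to its rank 0..len(orig_values)-1.

    Hypothetical helper. Raises if label contains a value that is not in
    orig_values, so unexpected labels fail loudly instead of being mislabeled.
    """
    orig = torch.as_tensor(sorted(orig_values), dtype=torch.long)
    lut = torch.full((int(orig.max()) + 1,), -1, dtype=torch.long)
    lut[orig] = torch.arange(len(orig))  # e.g. 102 -> 50, 103 -> 51
    idx = label.long()
    if idx.min() < 0 or idx.max() >= lut.numel():
        raise ValueError("label contains values outside the lookup table")
    out = lut[idx]
    if (out < 0).any():
        raise ValueError("label contains values not listed in orig_values")
    return out


orig_values = list(range(50)) + [102, 103]  # values reported in this thread
num_classes = len(orig_values)              # 52 -> required out_channels

y = torch.tensor([[0, 49, 102, 103]])
print(remap_labels(y, orig_values).tolist())  # [[0, 49, 50, 51]]
```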
