Hello,
I am interested in using PyTorch with @MONAI to segment multiple organs in fetal mouse scans. I’m a beginner and started by following the tutorial “3D Multi-organ Segmentation with UNETR”, which runs on my system with the provided sample data.
Now I’d like to adapt the example to my own data, which has 50 labels. I get a runtime error when the loss function is calculated. I believe it is related to the dimensions of the labels, although the shapes of the tensors passed to `loss_function` look as expected. Any advice on debugging this would be much appreciated.
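One thing worth checking when moving to 50 labels: the traceback below shows the loss one-hot encoding the target with `num_classes=n_pred_ch`, the number of channels in the network output. Every integer in the label volumes therefore has to lie in `[0, n_pred_ch - 1]`. Assuming background is 0 and the organs are numbered 1 to 50, the model would need at least 51 output channels. The tensor shapes can look correct while this still fails, because the constraint is on label values, not shapes. A minimal sketch of such a check, assuming the label volumes are already loaded as NumPy integer arrays; `find_out_of_range_labels` is a hypothetical helper, not a MONAI function, and the arrays are made-up example data.

```python
import numpy as np


def find_out_of_range_labels(label_volumes, num_classes):
    """Hypothetical helper: return (volume_index, bad_values) for every volume
    containing a label value outside [0, num_classes - 1]. Such values would be
    out-of-range channel indices for the one-hot scatter inside the loss."""
    problems = []
    for i, vol in enumerate(label_volumes):
        values = np.unique(vol)
        bad = values[(values < 0) | (values >= num_classes)]
        if bad.size > 0:
            problems.append((i, bad.tolist()))
    return problems


# Made-up example: background 0 plus labels up to 50.
labels = [np.array([[0, 1], [2, 50]]), np.array([[0, 0], [3, 4]])]

# A model with 14 output channels cannot represent label 50.
print(find_out_of_range_labels(labels, num_classes=14))  # [(0, [50])]
# 51 channels (background + 50 labels) covers every value.
print(find_out_of_range_labels(labels, num_classes=51))  # []
```

Running this over the training and validation labels before training narrows the problem down without involving the GPU.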
```
---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
/tmp/ipykernel_35590/2004812791.py in <module>
      1 while global_step < max_iterations:
----> 2     global_step, dice_val_best, global_step_best = train(
      3         global_step, train_loader, dice_val_best, global_step_best
      4     )
      5 model.load_state_dict(torch.load(os.path.join(root_dir, "best_metric_model.pth")))

/tmp/ipykernel_35590/1903283385.py in train(global_step, train_loader, dice_val_best, global_step_best)
     38         print('logit_map',logit_map.shape)
     39         print('y',y.shape)
---> 40         loss = loss_function(logit_map, y)
     41         loss.backward()
     42         epoch_loss += loss.item()

~/.local/lib/python3.8/site-packages/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs)
   1100         if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
   1101                 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1102             return forward_call(*input, **kwargs)
   1103         # Do not call functions when jit is used
   1104         full_backward_hooks, non_full_backward_hooks = [], []

~/.local/lib/python3.8/site-packages/monai/losses/dice.py in forward(self, input, target)
    719             raise ValueError("the number of dimensions for input and target should be the same.")
    720 
--> 721         dice_loss = self.dice(input, target)
    722         ce_loss = self.ce(input, target)
    723         total_loss: torch.Tensor = self.lambda_dice * dice_loss + self.lambda_ce * ce_loss

~/.local/lib/python3.8/site-packages/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs)
   1100         if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
   1101                 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1102             return forward_call(*input, **kwargs)
   1103         # Do not call functions when jit is used
   1104         full_backward_hooks, non_full_backward_hooks = [], []

~/.local/lib/python3.8/site-packages/monai/losses/dice.py in forward(self, input, target)
    144                 warnings.warn("single channel prediction, `to_onehot_y=True` ignored.")
    145             else:
--> 146                 target = one_hot(target, num_classes=n_pred_ch)
    147 
    148         if not self.include_background:

~/.local/lib/python3.8/site-packages/monai/networks/utils.py in one_hot(labels, num_classes, dtype, dim)
     89 
     90     o = torch.zeros(size=sh, dtype=dtype, device=labels.device)
---> 91     labels = o.scatter_(dim=dim, index=labels.long(), value=1)
     92 
     93     return labels

RuntimeError: CUDA error: device-side assert triggered
```
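Because CUDA kernels run asynchronously, the Python-side message for a device-side assert does not say which index was out of range. Setting the environment variable `CUDA_LAUNCH_BLOCKING=1` makes the error surface at the offending call, but running the same operation on CPU gives a readable message instead. The sketch below mirrors the `torch.zeros` + `scatter_` pattern visible in the `one_hot` frame above, using plain PyTorch on CPU. `one_hot_sketch` is a hypothetical stand-in, not MONAI's own `one_hot`, and the tiny label tensor is made-up example data.

```python
import torch


def one_hot_sketch(labels: torch.Tensor, num_classes: int, dim: int = 1) -> torch.Tensor:
    """Hypothetical stand-in for the one_hot frame in the traceback:
    expands labels of shape (B, 1, ...) to (B, num_classes, ...)."""
    shape = list(labels.shape)
    shape[dim] = num_classes
    out = torch.zeros(size=shape, dtype=torch.float, device=labels.device)
    # scatter_ uses the label values as channel indices, so every value
    # must be in [0, num_classes - 1]; otherwise it fails (on GPU, as a
    # device-side assert).
    return out.scatter_(dim=dim, index=labels.long(), value=1)


# Made-up batch: one single-channel 2x2 "volume" that contains label 50.
y = torch.tensor([[[[0, 1], [2, 50]]]])

try:
    one_hot_sketch(y, num_classes=14)  # on CPU this raises a RuntimeError naming the bad index
except RuntimeError as err:
    print(err)

print(one_hot_sketch(y, num_classes=51).shape)  # torch.Size([1, 51, 2, 2])
```

The same pattern works with a real label batch: move it to CPU with `y.cpu()` and call the function with `num_classes=logit_map.shape[1]`.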