MONAI tutorial debugging

Your shapes seem to work fine for me, as seen here:

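import monai
import torch

# Raw logits: 4 samples, 51 classes, 96x96x96 volumes
output = torch.randn(4, 51, 96, 96, 96, requires_grad=True)
# Integer class labels with a singleton channel dim, as to_onehot_y expects
target = torch.randint(0, 51, (4, 1, 96, 96, 96))

criterion = monai.losses.DiceLoss(softmax=True, to_onehot_y=True)
loss = criterion(output, target)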

Could you compare my code snippet with yours and check where the difference might be coming from?
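To narrow it down, it can help to check your tensors right before the loss call. Below is a minimal sketch of such a check. `check_dice_inputs` is a hypothetical helper, not part of MONAI, and it assumes the same layout as the snippet above: logits shaped `(N, C, *spatial)` and integer labels shaped `(N, 1, *spatial)` with values in `[0, C)`.

```python
import torch


def check_dice_inputs(output: torch.Tensor, target: torch.Tensor) -> None:
    """Hypothetical helper: sanity-check inputs for DiceLoss(to_onehot_y=True).

    Assumes output holds logits shaped (N, C, *spatial) and target holds
    integer class indices shaped (N, 1, *spatial).
    """
    if output.dim() < 3:
        raise ValueError(f"output should be (N, C, *spatial), got {tuple(output.shape)}")
    # A missing channel dim in target usually shows up as a dim-count mismatch
    if target.dim() != output.dim():
        raise ValueError(
            f"target should have the same number of dims as output, "
            f"got {target.dim()} vs {output.dim()}"
        )
    if target.shape[0] != output.shape[0]:
        raise ValueError(f"batch size mismatch: {target.shape[0]} vs {output.shape[0]}")
    if target.shape[1] != 1:
        raise ValueError(f"target channel dim should be 1, got {target.shape[1]}")
    if target.shape[2:] != output.shape[2:]:
        raise ValueError(
            f"spatial shape mismatch: {tuple(target.shape[2:])} vs {tuple(output.shape[2:])}"
        )
    # Labels must be valid class indices for the one-hot conversion
    num_classes = output.shape[1]
    if target.min() < 0 or target.max() >= num_classes:
        raise ValueError(
            f"target labels should lie in [0, {num_classes - 1}], "
            f"got [{int(target.min())}, {int(target.max())}]"
        )


# Small tensors with the same layout (2 samples, 3 classes, 8x8x8 volume)
out = torch.randn(2, 3, 8, 8, 8)
tgt = torch.randint(0, 3, (2, 1, 8, 8, 8))
check_dice_inputs(out, tgt)  # passes silently

bad_tgt = torch.randint(0, 3, (2, 8, 8, 8))  # channel dim missing
try:
    check_dice_inputs(out, bad_tgt)
except ValueError as err:
    print(err)
```

Running it on your real `output` and `target` should point at the first dimension or label range that differs from the working example.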