Torchgeometry dice loss unexpected result

Good afternoon, I am trying to understand dice loss better to fix some saturation issues in my segmentation code, but even a simple example like the one below returns an unexpected result.

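```python
import torch
from torchgeometry.losses.dice import dice_loss

pred = torch.tensor([[[0, 1], [2, 3]]])
# one-hot encode the class ids and move classes to dim 1: (N, H, W) -> (N, C, H, W)
pred = torch.nn.functional.one_hot(pred, num_classes=4).permute(0, 3, 1, 2).float()
targ = torch.tensor([[[0, 1], [2, 3]]])
print(dice_loss(pred, targ))
```

Output:

```
tensor(0.5246)
```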

It is my understanding that dice loss should report a perfect match upon perfect class overlap (a Dice coefficient of 1.0, i.e. a loss of 0.0), but as you can see above, it does not. I set the dtypes as specified (float32 for the predictions, int64 for the targets), passed the predictions as one-hot (N, C, H, W) and the targets as (N, H, W), as shown above. Is there a reason this does not give a perfect score?

The predictions are not supposed to be one-hot, as the docs explain the inputs as:

X expects to be the scores of each class.
Y expects to be the one-hot tensor with the class labels.

where the scores are passed through `F.softmax(input, dim=1)` before the dice loss is calculated, so a one-hot input is turned into soft probabilities and no longer matches the target exactly.
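To see where 0.5246 comes from, here is a minimal pure-PyTorch sketch. It assumes the loss is computed as a softmax over the class dimension, then 1 − 2·|P∩T| / (|P| + |T|) with the sums taken over (C, H, W) and the result averaged over the batch, which is how a softmax-based dice loss is commonly written. `manual_dice_loss` is a hypothetical helper for illustration, not the library function, and the `eps` value is an assumption.

```python
import torch
import torch.nn.functional as F


def manual_dice_loss(scores: torch.Tensor, target: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
    """Hypothetical softmax-based dice loss (sketch, not the library code).

    scores: (N, C, H, W) float logits; target: (N, H, W) int64 class ids.
    """
    probs = F.softmax(scores, dim=1)  # (N, C, H, W), sums to 1 over C
    target_one_hot = F.one_hot(target, num_classes=scores.shape[1])
    target_one_hot = target_one_hot.permute(0, 3, 1, 2).to(probs.dtype)  # (N, C, H, W)
    dims = (1, 2, 3)
    intersection = torch.sum(probs * target_one_hot, dims)
    cardinality = torch.sum(probs + target_one_hot, dims)
    dice = 2.0 * intersection / (cardinality + eps)  # eps avoids division by zero
    return torch.mean(1.0 - dice)


targ = torch.tensor([[[0, 1], [2, 3]]])  # (N, H, W) class ids
one_hot_pred = F.one_hot(targ, num_classes=4).permute(0, 3, 1, 2).float()

# A "score" of 1 vs 0 only gives the correct class about 0.475 after softmax.
print(F.softmax(one_hot_pred, dim=1)[0, :, 0, 0])  # approx. [0.4754, 0.1749, 0.1749, 0.1749]
print(manual_dice_loss(one_hot_pred, targ))  # approx. 0.5246, as in the question

# Confident logits (large margin) make the softmax nearly one-hot, so the loss approaches 0.
print(manual_dice_loss(one_hot_pred * 100.0, targ))
```

With one-hot inputs, each pixel's correct class gets e / (e + 3) ≈ 0.475 after the softmax, so the Dice coefficient is about 0.475 and the loss about 0.525. Passing real logits with a clear margin drives the loss towards 0.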

Thank you, the softmax seems to be what causes the unexpected result. However, the doc you quoted is still quite confusing, because three lines further down the docs state:

Shape:

  • Input: (N,C,H,W) where C = number of classes.
  • Target: (N,H,W) where each value is 0≤targets[i]≤C−1

(N, H, W) is not a one-hot tensor. In fact, when I try to pass Y as a one-hot tensor, dice_loss raises “Invalid depth shape, we expect BxHxW. Got: torch.Size([1, 4, 2, 2])”.
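If the targets are already one-hot encoded, one simple way to get the (N, H, W) class-id layout the function asks for is to take the argmax over the class dimension. This sketch assumes the (N, C, H, W) one-hot layout used in the question:

```python
import torch
import torch.nn.functional as F

# One-hot targets in (N, C, H, W) layout, as tried above.
targ = torch.tensor([[[0, 1], [2, 3]]])
targ_one_hot = F.one_hot(targ, num_classes=4).permute(0, 3, 1, 2)
print(targ_one_hot.shape)  # torch.Size([1, 4, 2, 2]), the shape that is rejected

# Collapse the class dimension back to integer class ids.
targ_ids = targ_one_hot.argmax(dim=1)  # (N, H, W), int64
print(targ_ids.shape, targ_ids.dtype)  # torch.Size([1, 2, 2]) torch.int64
```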

The text is indeed confusing, as internally the function builds the one-hot tensor from the labels itself:

```python
# create the labels one hot tensor
target_one_hot = one_hot(target, num_classes=input.shape[1],
                         device=input.device, dtype=input.dtype)
```
I guess the authors are trying to say that Y is expected to contain the class labels, from which the one-hot tensor is then created?
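For illustration, the label-to-one-hot conversion can be sketched in plain PyTorch with `Tensor.scatter_`. `labels_to_one_hot` is a hypothetical helper written for this example, not the library's `one_hot` function; it only aims at the (N, C, H, W) layout the loss works with.

```python
import torch


def labels_to_one_hot(labels: torch.Tensor, num_classes: int,
                      dtype: torch.dtype = torch.float32) -> torch.Tensor:
    """Hypothetical helper: (N, H, W) int64 class ids -> (N, C, H, W) one-hot."""
    if labels.dim() != 3:
        raise ValueError(f"expected (N, H, W) labels, got {tuple(labels.shape)}")
    n, h, w = labels.shape
    one_hot = torch.zeros(n, num_classes, h, w, dtype=dtype, device=labels.device)
    # Write 1.0 into the channel given by each pixel's class id.
    return one_hot.scatter_(1, labels.unsqueeze(1), 1.0)


targ = torch.tensor([[[0, 1], [2, 3]]])
print(labels_to_one_hot(targ, num_classes=4)[0])  # channel c is 1 where targ == c
```

This is also why an already one-hot Y fails: a 4-D tensor cannot be treated as (N, H, W) labels.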

I also don’t know whether torchgeometry is still supported or whether it was replaced by kornia, which uses the same docs.

CC @edgarriba

@ptrblck you are right on both counts. Y is expected to be the tensor of label ids, from which the function internally creates the one-hot tensor for you. And yes, torchgeometry hasn’t existed for more than 3-4 years; it evolved into what is kornia today.

@paulyoung feel free to send a PR to fix any misleading documentation.