Dice score changes for the same input tensors after reshaping

I’m calculating the Dice score to evaluate my model for a binary image segmentation problem.

The function I wrote in PyTorch is:

```python
import torch

def dice_score_reduced_over_batch(x, y, smooth=1):
    assert x.ndim == y.ndim
    # reduction over all axes except 0, i.e. the batch axis
    axes = tuple(range(1, x.ndim))

    intersection = torch.abs((x * y).sum(dim=axes))  # one value per sample
    union = torch.abs(x.sum(dim=axes)) + torch.abs(y.sum(dim=axes))  # one value per sample
    # per-sample Dice, then the mean over the batch
    dice = torch.mean(2. * (intersection + smooth) / (union + smooth), dim=0)
    return dice
```

The input tensors x and y have the shape [batch_size, nChannel, height, width], where nChannel = 1 since the ground truth is a 2D binary mask. The standard way to calculate the Dice score is to compute it for each sample (reducing over every axis except the batch axis) and then take the mean over the batch (right?). However, I found that the score depends on how the inputs are flattened:

╔═══════════════════╦══════════════════╦════════╗
║ input tensor      ║ flattened tensor ║ dice   ║
╠═══════════════════╬══════════════════╬════════╣
║ [64, 1, 128, 128] ║ -                ║ 0.2754 ║
╠═══════════════════╬══════════════════╬════════╣
║ [64, 1, 128, 128] ║ [64, 16384]      ║ 0.2754 ║
╠═══════════════════╬══════════════════╬════════╣
║ [64, 1, 128, 128] ║ [1, 1048576]     ║ 0.3121 ║
╚═══════════════════╩══════════════════╩════════╝

My best guess was that this difference comes from the way values are averaged, but that doesn't seem to be the case. The code should return exactly the same answer regardless of the arrangement or shape of the input data. How can this behavior be explained, and what's the best way to avoid it?

Hi Stark!

Flattening [64, 1, 128, 128] to [1, 1048576] mixes your batch
dimension in with your image dimensions.
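To make this concrete, here are the two quantities read directly off the function above. The notation is introduced here for illustration: $B = 64$ samples, $P = 128 \cdot 128 = 16384$ pixels per sample, $s$ = `smooth`, and per-sample sums $I_b = \sum_p x_{b,p}\, y_{b,p}$, $X_b = \sum_p x_{b,p}$, $Y_b = \sum_p y_{b,p}$. For a [64, 1, 128, 128] or [64, 16384] input the function returns

$$D_{\text{batch}} = \frac{1}{B}\sum_{b=1}^{B} \frac{2\,\left(|I_b| + s\right)}{|X_b| + |Y_b| + s}$$

whereas a [1, 1048576] input is treated as a single "sample" that pools all $B \cdot P$ pixels:

$$D_{\text{flat}} = \frac{2\,\left(\left|\sum_{b} I_b\right| + s\right)}{\left|\sum_{b} X_b\right| + \left|\sum_{b} Y_b\right| + s}$$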

Based on your comments and the comment in your code (rather than
on my absent knowledge of the Dice score), it seems that the Dice
score should be independent of how the image dimensions are
rearranged, but that the batch dimension plays a different role.

Note, for example, that

```python
intersection = torch.abs((x * y).sum(dim=axes))
```

performs the image sum before calling abs(). If you also sum over
the batch before calling abs() (and negative values are present), you’ll
get a different value.
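A toy illustration with made-up numbers: two per-sample intersection sums, one of them negative. Negative values could arise, for example, if x held raw logits rather than probabilities (an assumption; the question doesn't say what x contains). Taking abs() per sample and taking it after pooling the batch give different results.

```python
import torch

# Hypothetical per-sample intersection sums; the negative entry is an assumed example.
per_sample = torch.tensor([3.0, -2.0])

abs_then_sum = torch.abs(per_sample).sum()  # abs per sample, then pool: 3 + 2
sum_then_abs = torch.abs(per_sample.sum())  # pool first (as flattening does), then abs: |3 - 2|

print(abs_then_sum.item(), sum_then_abs.item())  # 5.0 1.0
```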

Similarly,

```python
dice = torch.mean(2. * (intersection + smooth) / (union + smooth), dim=0)
```

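takes the batch mean after performing the division. If you flatten
nBatch = 64 away, the sums over the batch are effectively taken before
the division, and you will get a different answer: a mean of ratios is
not, in general, equal to the ratio of the summed numerators and denominators.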
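A small worked example with made-up per-sample sums for a batch of two: sample 0 is a small, nearly perfectly matched mask, and sample 1 is a large, poorly matched one. Averaging per-sample scores weights both samples equally, while pooling first lets the large sample dominate.

```python
import torch

smooth = 1.0
# Hypothetical per-sample sums (made-up numbers for illustration).
intersection = torch.tensor([9.0, 99.0])
union = torch.tensor([19.0, 999.0])

# What the function computes for a [B, ...] input: divide per sample, then average.
mean_of_ratios = torch.mean(2. * (intersection + smooth) / (union + smooth), dim=0)

# What it effectively computes after flattening to [1, N]: pool the sums, then divide once.
ratio_of_sums = 2. * (intersection.sum() + smooth) / (union.sum() + smooth)

print(mean_of_ratios.item(), ratio_of_sums.item())  # roughly 0.6 vs 0.214
```

Which of the two you want is a modelling choice, but the per-sample version is what your function computes for the unflattened input.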

Best.

K. Frank