# Dice score changes for the same input tensors after reshaping

I’m calculating the Dice score to evaluate my model for a binary image segmentation problem.

The function I wrote in PyTorch is:

```python
import torch

def dice_score_reduced_over_batch(x, y, smooth=1):
    assert x.ndim == y.ndim
    # reduction over all axes except 0, i.e. batch
    axes = tuple(range(1, x.ndim))

    intersection = torch.abs((x * y).sum(dim=axes))
    union = torch.abs(x.sum(dim=axes)) + torch.abs(y.sum(dim=axes))
    # per-sample Dice, then mean over the batch (dim 0)
    dice = torch.mean(2. * (intersection + smooth) / (union + smooth), dim=0)
    return dice
```

The input tensors `x` and `y` have the shape `[batch_size, nChannel, height, width]`, where `nChannel=1` since the ground truth is a 2D binary mask. The standard way to calculate the Dice score is to compute it per sample (reducing over every axis except `batch`) and take the mean over the batch at the end (right?). I found that the score is affected by the way the inputs are flattened:

| Input tensor        | Flattened to     | Dice   |
|---------------------|------------------|--------|
| `[64, 1, 128, 128]` | (none)           | 0.2754 |
| `[64, 1, 128, 128]` | `[64, 16384]`    | 0.2754 |
| `[64, 1, 128, 128]` | `[1, 1048576]`   | 0.3121 |
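A minimal sketch that reproduces the same pattern, assuming random binary masks in place of real model outputs (the data, seed, and per-sample foreground fractions are made up, so the numbers will not match the table above):

```python
import torch

def dice_score_reduced_over_batch(x, y, smooth=1):
    # Copy of the function from the question: per-sample Dice, then batch mean
    assert x.ndim == y.ndim
    axes = tuple(range(1, x.ndim))
    intersection = torch.abs((x * y).sum(dim=axes))
    union = torch.abs(x.sum(dim=axes)) + torch.abs(y.sum(dim=axes))
    return torch.mean(2. * (intersection + smooth) / (union + smooth), dim=0)

# Assumption: random binary masks stand in for real predictions/targets.
# The target foreground fraction varies per sample so per-sample scores differ.
torch.manual_seed(0)
pred = (torch.rand(64, 1, 128, 128) > 0.5).float()
frac = torch.linspace(0.05, 0.95, 64).view(64, 1, 1, 1)
target = (torch.rand(64, 1, 128, 128) < frac).float()

d_4d = dice_score_reduced_over_batch(pred, target)
d_per_sample = dice_score_reduced_over_batch(pred.reshape(64, -1), target.reshape(64, -1))
d_pooled = dice_score_reduced_over_batch(pred.reshape(1, -1), target.reshape(1, -1))
print(d_4d.item(), d_per_sample.item(), d_pooled.item())
```

The first two values agree because the batch axis is left alone; the third differs because all 64 samples end up in a single sum.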

My best guess was that the difference comes from the way the values are averaged, but that doesn't seem to be the case. The code should return exactly the same answer regardless of the arrangement/shape of the input data. How can this behavior be explained? What's the best way to avoid it?

Hi Stark!

Flattening `[64, 1, 128, 128]` to `[1, 1048576]` mixes your batch
dimension in with your image dimensions.
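To illustrate with made-up data: two samples of four pixels each, where the first prediction is perfect and the second misses entirely. The helper `dice` below is a hypothetical, simplified version of the function in the question (no `abs()`, and `smooth=0` to isolate the effect of mixing the batch into the sums):

```python
import torch

# Two made-up samples with 4 pixels each (shape [batch=2, pixels=4])
x = torch.tensor([[1., 1., 1., 1.],
                  [1., 0., 0., 0.]])
y = torch.tensor([[1., 1., 1., 1.],
                  [0., 1., 0., 0.]])

def dice(x, y, smooth=0.0):
    # Per-row Dice (reduce over dim 1), then mean over rows
    inter = (x * y).sum(dim=1)
    union = x.sum(dim=1) + y.sum(dim=1)
    return (2. * (inter + smooth) / (union + smooth)).mean()

per_sample = dice(x, y)                            # mean of 1.0 and 0.0
pooled = dice(x.reshape(1, -1), y.reshape(1, -1))  # 2 * 4 / (5 + 5)
```

Keeping the batch axis gives 0.5, while flattening it away gives 0.8: the perfect sample's large overlap dominates the pooled sums.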

Based on my absent knowledge of the Dice score, it seems that the Dice
score should be independent of how the image dimensions are
rearranged, but that the batch dimension plays a different role: the
score is computed separately for each sample and only then averaged.
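One way to get a result that does not depend on how the image dimensions are arranged is to always keep the batch axis first and flatten everything else with `reshape(x.shape[0], -1)`. A sketch, where the wrapper name `dice_keep_batch` is hypothetical and the `abs()` calls are dropped on the assumption that inputs are non-negative (probabilities or binary masks):

```python
import torch

def dice_keep_batch(x, y, smooth=1):
    # Hypothetical wrapper: flatten every non-batch dim, so any rearrangement
    # of the image dims (with batch kept at dim 0) gives the same score.
    # Assumes non-negative inputs, so the question's abs() calls are omitted.
    assert x.shape == y.shape
    x = x.reshape(x.shape[0], -1)
    y = y.reshape(y.shape[0], -1)
    intersection = (x * y).sum(dim=1)
    union = x.sum(dim=1) + y.sum(dim=1)
    # per-sample Dice (same smoothing as the question), then batch mean
    return (2. * (intersection + smooth) / (union + smooth)).mean()

torch.manual_seed(0)
x = (torch.rand(8, 1, 16, 16) > 0.5).float()
y = (torch.rand(8, 1, 16, 16) > 0.5).float()

a = dice_keep_batch(x, y)
b = dice_keep_batch(x.permute(0, 1, 3, 2), y.permute(0, 1, 3, 2))  # swap H and W
c = dice_keep_batch(x.flip(-1), y.flip(-1))                        # mirror images
```

All three calls return the same score, because each sample's pixels still land in that sample's own sums.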

Note, for example, that

```python
intersection = torch.abs((x * y).sum(dim=axes))
```

performs the image sum before calling `abs()`. If you also sum over
the batch before calling `abs()` (and negative values are present), you’ll
get a different value.
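A small made-up example of the `abs()` ordering issue, assuming (hypothetically) that `x * y` contains negative values, for instance because raw logits rather than probabilities are passed in:

```python
import torch

# Made-up per-pixel products for 2 samples x 2 pixels, with negative values
xy = torch.tensor([[ 3., -1.],
                   [-4.,  1.]])

per_sample = torch.abs(xy.sum(dim=1))  # |3 - 1| = 2 and |-4 + 1| = 3
total_per_sample = per_sample.sum()    # 2 + 3 = 5
pooled = torch.abs(xy.sum())           # |3 - 1 - 4 + 1| = 1
```

Summing over the batch first lets positive and negative contributions from different samples cancel before `abs()` is applied.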

Similarly,

```python
dice = torch.mean(2. * (intersection + smooth) / (union + smooth), dim=0)
```

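takes the batch mean after performing the division. If you instead fold
the batch into the sums (by flattening `nBatch = 64` away) before you
perform the division, you compute a ratio of batch-wide sums rather than
a mean of per-sample ratios, and you will, in general, get a different answer.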
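In formula form, ignoring the `abs()` calls (i.e. assuming non-negative inputs), with $B$ the batch size, $i$ indexing the pixels of a sample, and $s$ standing for `smooth`, the function in the question computes the first expression below, while flattening to `[1, N]` sets $B = 1$ and yields the second:

$$
\begin{aligned}
D_{\text{per-sample}} &= \frac{1}{B}\sum_{b=1}^{B}\frac{2\left(\sum_i x_{b,i}\,y_{b,i} + s\right)}{\sum_i x_{b,i} + \sum_i y_{b,i} + s},\\
D_{\text{pooled}} &= \frac{2\left(\sum_{b,i} x_{b,i}\,y_{b,i} + s\right)}{\sum_{b,i} x_{b,i} + \sum_{b,i} y_{b,i} + s}.
\end{aligned}
$$

A mean of ratios equals the ratio of the sums only in special cases, e.g. $B = 1$, or $s = 0$ with every sample having the same denominator.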

Best.

K. Frank