In some examples I see Dice loss computed per sample, keeping the batch dimension separate (`inputs.reshape(inputs.shape[0], -1)`):
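```python
def dice_loss(inputs, target):
    num = target.size(0)  # batch size
    # Flatten each sample separately: shape (num, C*H*W)
    inputs = inputs.reshape(num, -1)
    target = target.reshape(num, -1)
    intersection = (inputs * target).sum(1)
    # Per-sample sum of the two "cardinalities" (not a true set union)
    union = inputs.sum(1) + target.sum(1)
    dice = (2. * intersection) / (union + 1e-8)  # one Dice score per sample
    dice = dice.sum() / num  # average over the batch
    return 1 - dice
```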
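Written out, this first version computes one Dice score per sample and then averages over the batch. Assuming $N$ is the batch size, $p_{n,i}$ the predicted value and $g_{n,i}$ the target value of element $i$ in sample $n$, and $\epsilon = 10^{-8}$ (this notation is mine, not from the examples):

$$
\mathcal{L}_{\text{per-sample}} = 1 - \frac{1}{N}\sum_{n=1}^{N} \frac{2\sum_{i} p_{n,i}\, g_{n,i}}{\sum_{i} p_{n,i} + \sum_{i} g_{n,i} + \epsilon}
$$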
In other examples they flatten the entire tensor, batch dimension included (`inputs.reshape(-1)`):
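```python
def dice_loss(inputs, target):
    # Flatten everything, batch included: one Dice score for the whole batch
    inputs = inputs.reshape(-1)
    target = target.reshape(-1)
    intersection = (inputs * target).sum()
    # Sum of the two "cardinalities" (not a true set union)
    union = inputs.sum() + target.sum()
    dice = (2. * intersection) / (union + 1e-8)
    # dice is already a scalar, so no batch averaging is needed
    return 1 - dice
```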
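Written out, this second version pools the intersection and the sums over every element of every sample before dividing once. Assuming $N$ is the batch size, $p_{n,i}$ the predicted value and $g_{n,i}$ the target value of element $i$ in sample $n$, and $\epsilon = 10^{-8}$ (notation assumed for illustration):

$$
\mathcal{L}_{\text{global}} = 1 - \frac{2\sum_{n=1}^{N}\sum_{i} p_{n,i}\, g_{n,i}}{\sum_{n=1}^{N}\sum_{i} p_{n,i} + \sum_{n=1}^{N}\sum_{i} g_{n,i} + \epsilon}
$$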
I am confused: which of the two is correct, and what is the practical difference between them?
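A minimal sketch of how the two reductions can disagree, assuming PyTorch and a made-up batch of two 1×10×10 binary masks: sample 0 has a large object predicted perfectly, sample 1 has a tiny object that is missed entirely. The function names `dice_loss_per_sample` and `dice_loss_global` are hypothetical labels for the two variants quoted in the question.

```python
import torch


def dice_loss_per_sample(inputs, target, eps=1e-8):
    """Hypothetical name: Dice per sample, then averaged over the batch."""
    num = target.size(0)
    inputs = inputs.reshape(num, -1)
    target = target.reshape(num, -1)
    intersection = (inputs * target).sum(1)
    cardinality = inputs.sum(1) + target.sum(1)
    dice = (2.0 * intersection) / (cardinality + eps)  # shape (num,)
    return 1 - dice.mean()


def dice_loss_global(inputs, target, eps=1e-8):
    """Hypothetical name: a single Dice score over the whole flattened batch."""
    inputs = inputs.reshape(-1)
    target = target.reshape(-1)
    intersection = (inputs * target).sum()
    cardinality = inputs.sum() + target.sum()
    dice = (2.0 * intersection) / (cardinality + eps)  # scalar
    return 1 - dice


# Made-up batch of two binary masks, shape (N=2, C=1, H=10, W=10)
target = torch.zeros(2, 1, 10, 10)
target[0, :, :8, :8] = 1.0  # sample 0: large object (64 pixels)
target[1, :, :2, :2] = 1.0  # sample 1: tiny object (4 pixels)

inputs = torch.zeros_like(target)
inputs[0] = target[0]  # sample 0 predicted perfectly
# inputs[1] stays all zeros: the tiny object is missed

print(dice_loss_per_sample(inputs, target))
print(dice_loss_global(inputs, target))
```

Here the per-sample loss is 0.5 (Dice of 1 for sample 0 and 0 for sample 1, averaged), while the global loss is only about 0.03 (2·64 / (64 + 68)), because the 4 missed pixels are swamped by the 64 correct ones. In other words, the per-sample version weights every image equally, whereas the flattened version effectively weights images by how many foreground pixels they contain. With a batch size of 1 the two are identical.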