Dice loss becoming negative

Hey, I am training a simple UNet with dice and BCE loss on the Salt segmentation challenge on Kaggle. My model's dice loss goes negative after a while, and soon after so does the BCE loss. In this example, I pick a dataset of only 5 examples and plan to overfit.


import torch
import torch.nn as nn
from torch.utils.data import DataLoader


def dice(input, target):
    smooth = .001
    input = input.view(-1)
    target = target.view(-1)

    return 1 - 2 * (input * target).sum() / (input.sum() + target.sum() + smooth)
        
batch_size = 10

dataset = DatasetSalt(limit_paths=10)
dataloader = DataLoader(dataset, batch_size, shuffle=True, num_workers=2)

net = UNet()
optimizer = torch.optim.Adam(net.parameters(), lr=10e-3)
criterion = nn.BCEWithLogitsLoss()


def train(epoch):
    for idx, batch_data in enumerate(dataloader):
        x, target = batch_data['image'].float(), batch_data['label'].float()

        optimizer.zero_grad()
        output = net(x)
        output.squeeze_(1)  # drop the channel dimension so output matches target

        bce_loss = criterion(output, target)
        dice_loss = dice(output, target)
        loss = bce_loss + dice_loss
        loss.backward()
        optimizer.step()

        print('Epoch {}, loss {}, bce {}, dice {}'.format(
            epoch, loss.item(), bce_loss.item(), dice_loss.item()))


    

for epoch in range(1000):
    train(epoch)

Output:

I stop the training here and the results seem quite poor.

Update: Training the network on BCE alone works great. But when dice is added, the loss goes negative, then both losses reset to a high value and start heading toward negative again, cycling like that.

The input argument of the dice function should be between 0 and 1, but in your case it is not: output appears to be the raw logits of the network, which is consistent with your use of nn.BCEWithLogitsLoss for the BCE part.
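Because the logits are unbounded, 2 * (input * target).sum() can exceed input.sum() + target.sum(), which is exactly what pushes the dice loss below zero. A minimal sketch of the fix, assuming the model keeps returning raw logits (the sigmoid call and the smoothing term in the numerator are my additions, not something from your code):

import torch

def dice(logits, target, smooth=1e-3):
    # Squash the raw logits into [0, 1] before computing the dice score
    probs = torch.sigmoid(logits).view(-1)
    target = target.view(-1)
    intersection = (probs * target).sum()
    # With probabilities in [0, 1] and smoothing in both numerator and
    # denominator, the returned loss stays in [0, 1]
    return 1 - (2 * intersection + smooth) / (probs.sum() + target.sum() + smooth)

Alternatively, leave the dice function unchanged and pass torch.sigmoid(output) to it in the training loop, while still feeding the raw output to BCEWithLogitsLoss.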


Your dice function seems to handle a single sample's output and target rather than a batch.
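
The dice function in the original post flattens everything with view(-1), so one dice score is computed over all pixels of all samples at once. A per-sample (batch-wise) version could look like the sketch below; the (N, H, W) shape, the smoothing value, and the averaging over the batch are assumptions for illustration, not something prescribed in the thread:

import torch

def dice_per_sample(probs, target, smooth=1e-3):
    # probs and target are expected to have shape (N, H, W) with values in [0, 1]
    n = probs.size(0)
    probs = probs.reshape(n, -1)    # flatten each sample separately
    target = target.reshape(n, -1)
    intersection = (probs * target).sum(dim=1)
    dice = (2 * intersection + smooth) / (probs.sum(dim=1) + target.sum(dim=1) + smooth)
    return 1 - dice.mean()          # average the per-sample dice over the batch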

Can you explain this in more detail? I mean a batch-wise implementation of the dice loss.

Thanks