Hey, I am training a simple UNet with a combined dice and BCE loss on the Salt segmentation challenge on Kaggle. My model's dice loss goes negative after a while, and soon after the BCE loss does too. In this example, I pick a dataset of only 5 examples and plan to overfit.
import torch
import torch.nn as nn
from torch.utils.data import DataLoader

def dice(input, target):
    # soft dice loss on flattened tensors; smooth avoids division by zero
    smooth = 0.001
    input = input.view(-1)
    target = target.view(-1)
    return 1 - 2 * (input * target).sum() / (input.sum() + target.sum() + smooth)

# UNet and DatasetSalt are my own classes, defined elsewhere
batch_size = 10
net = UNet()
optimizer = torch.optim.Adam(net.parameters(), lr=10e-3)
criterion = nn.BCEWithLogitsLoss()
dataset = DatasetSalt(limit_paths=10)
dataloader = DataLoader(dataset, batch_size, shuffle=True, num_workers=2)

def train(epoch):
    for idx, batch_data in enumerate(dataloader):
        x, target = batch_data['image'].float(), batch_data['label'].float()
        optimizer.zero_grad()
        output = net(x)
        output.squeeze_(1)                      # drop the channel dimension
        bce_loss = criterion(output, target)    # BCEWithLogitsLoss applies the sigmoid internally
        dice_loss = dice(output, target)        # dice gets the same raw network output
        loss = bce_loss + dice_loss
        loss.backward()
        optimizer.step()
        print('Epoch {}, loss {}, bce {}, dice {}'.format(
            epoch, loss.item(), bce_loss.item(), dice_loss.item()))

for epoch in range(1000):
    train(epoch)
Output:
I stop the training here and the results seem quite poor.
Update: Training the network on BCE alone works great. But when dice is added, the loss goes negative, then both losses reset to a high value, start heading toward negative again, and keep cycling like this.
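For reference, here is a tiny standalone check of the dice() function above, detached from my model and data (the prediction values are just made up): when its input is a probability in [0, 1] the result stays in the expected [0, 1] range, but when I feed it unbounded values like raw network outputs it can drop below zero, which matches what I see during training.

import torch

def dice(input, target):
    smooth = 0.001
    input = input.view(-1)
    target = target.view(-1)
    return 1 - 2 * (input * target).sum() / (input.sum() + target.sum() + smooth)

target = torch.ones(4)                          # four foreground pixels

probs = torch.tensor([0.9, 0.8, 0.95, 0.7])     # predictions already in [0, 1]
print(dice(probs, target))                      # ~0.09, inside the expected range

raw = torch.tensor([5.0, 7.0, 6.0, 4.0])        # unbounded values, like raw outputs
print(dice(raw, target))                        # ~-0.69, negative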