Constant segmentation loss

Hey, I am training a simple UNet with dice and BCE loss on the Salt segmentation challenge on Kaggle. My model's loss is not changing at all. In this example, I pick a dataset of only 5 examples, keep iterating through it, and get a constant loss. I think my gradients are not getting backpropagated. What can I do?
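(A quick way to check whether gradients actually reach the parameters, as a minimal sketch assuming the `net` and `loss` defined in the code below:)

# Minimal sketch: inspect gradient magnitudes right after loss.backward()
# (assumes `net` and `loss` as in the training code below).
for name, param in net.named_parameters():
    if param.grad is None:
        print(name, 'has no grad')
    else:
        print(name, param.grad.abs().sum().item())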

import glob
import os

import cv2
import torch
from torch.utils.data import Dataset, DataLoader


class DatasetSalt(Dataset):

    def __init__(self, file_path='/Users/admin/deepschool.io/salt/images/*', transform=None, limit_paths=0):
        self.path = glob.glob(file_path)
        self.path = self.path[:limit_paths]
        self.transform = transform

    def __len__(self):
        return len(self.path)

    def __getitem__(self, index):
        # crop 101x101 -> 100x100 so the UNet down/up-sampling stays symmetric
        images = cv2.imread(self.path[index])
        images = images[:-1, :-1, :]
        images = torch.from_numpy(images)
        # HWC -> CHW
        images.transpose_(1, 2)
        images.transpose_(0, 1)

        # the mask shares the image's filename
        mks = '/Users/admin/deepschool.io/salt/masks/' + os.path.basename(self.path[index])

        labels = cv2.imread(mks)[:-1, :-1, 0]

        sample = {'image': images, 'label': torch.from_numpy(labels)}

        if self.transform is not None:
            sample = self.transform(sample)

        return sample

import torch.optim as optim

criterion = torch.nn.BCEWithLogitsLoss()
optimizer = optim.Adam(net.parameters(), lr=10)

def dice(input, target):
    # soft dice loss: 1 - 2 * intersection / (sum of inputs + sum of targets + smooth)
    smooth = .001
    input = input.view(-1)
    target = target.view(-1)

    return 1 - 2 * (input * target).sum() / (input.sum() + target.sum() + smooth)
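
(Not part of the original post, just a note: BCEWithLogitsLoss takes raw logits, and the same logits are fed to dice here. If the dice term should be computed on probabilities instead, a sigmoid can be applied first; a minimal, hypothetical variant:)

# Hypothetical variant (not in the original code): dice on probabilities,
# squashing the network logits with a sigmoid before reusing dice() above.
def dice_on_probs(logits, target):
    probs = torch.sigmoid(logits)  # map logits to [0, 1]
    return dice(probs, target)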
    

batch_size = 5
net = UNet()

def train():

    list_dice = []
    cross = []

    dataset = DatasetSalt(limit_paths=5)
    dataloader = DataLoader(dataset, batch_size, shuffle=True, num_workers=2)

    for idx, batch_data in enumerate(dataloader):
        inputs, labels = batch_data['image'].float(), batch_data['label'].float()

        out = net(inputs)

        BCE = criterion(out.view(batch_size, 100, 100), labels)
        dice_loss = dice(out.view(batch_size, 100, 100), labels)

        # keep running averages of both losses
        # (note: appending the undetached tensors keeps their graphs alive)
        list_dice.append(dice_loss)
        cross.append(BCE)

        avg1 = sum(list_dice) / len(list_dice)
        avg2 = sum(cross) / len(cross)

        print(avg1[0].item(), avg2[0].item())  # print average loss so far

        loss = BCE + dice_loss

        loss.backward()
        optimizer.step()
        optimizer.zero_grad()

Output:
(0.7458401918411255, 0.568533182144165)
(0.7458400726318359, 0.5685329437255859)
(0.7458401322364807, 0.5685330629348755)
(0.7458401918411255, 0.5685333609580994)
(0.7458401918411255, 0.568533182144165)
(0.7458400726318359, 0.5685329437255859)
(0.7458401322364807, 0.5685330629348755)
(0.7458401918411255, 0.5685333609580994)

limit_paths just limits the dataset to a smaller size. I tried training on only BCE or only dice, and even changing the lr from 10 down to 0.0001, with no luck; the changes to the loss are very tiny. I trimmed the images from 101x101 to 100x100 to make my network symmetric for the UNet. I think the error may be in the loss function. My mask is 101x101 while the input is 101x101x3.
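
(Not from the original thread, but a quick sanity check worth doing here: print the shapes the loss actually sees and the range of the mask values, since cv2.imread typically returns masks as 0/255 while BCEWithLogitsLoss expects targets in [0, 1]. A minimal sketch, assuming the DatasetSalt class above:)

# Hypothetical sanity check (assumes DatasetSalt from above):
dataset = DatasetSalt(limit_paths=5)
sample = dataset[0]
print(sample['image'].shape)   # expected: torch.Size([3, 100, 100])
print(sample['label'].shape)   # expected: torch.Size([100, 100])
print(sample['label'].min(), sample['label'].max())  # targets for BCE should lie in [0, 1]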

Code for the UNet, as requested:

class UNet_down_block(torch.nn.Module):
    def __init__(self, input_channel, output_channel, down_size):
        super(UNet_down_block, self).__init__()
        self.conv1 = torch.nn.Conv2d(input_channel, output_channel, 3, padding=1)
        self.bn1 = torch.nn.BatchNorm2d(output_channel)
        self.conv2 = torch.nn.Conv2d(output_channel, output_channel, 3, padding=1)
        self.bn2 = torch.nn.BatchNorm2d(output_channel)
        self.conv3 = torch.nn.Conv2d(output_channel, output_channel, 3, padding=1)
        self.bn3 = torch.nn.BatchNorm2d(output_channel)
        self.max_pool = torch.nn.MaxPool2d(2, 2)
        self.relu = torch.nn.ReLU()
        self.down_size = down_size

    def forward(self, x):
        if self.down_size:
            x = self.max_pool(x)
        x = self.relu(self.bn1(self.conv1(x)))
        x = self.relu(self.bn2(self.conv2(x)))
        x = self.relu(self.bn3(self.conv3(x)))
        return x

class UNet_up_block(torch.nn.Module):
    def __init__(self, prev_channel, input_channel, output_channel):
        super(UNet_up_block, self).__init__()
        self.up_sampling = torch.nn.Upsample(scale_factor=2, mode='bilinear')
        self.conv1 = torch.nn.Conv2d(input_channel + input_channel, output_channel, 3, padding=1)
        self.bn1 = torch.nn.BatchNorm2d(output_channel)
        self.conv2 = torch.nn.Conv2d(output_channel, output_channel, 3, padding=1)
        self.bn2 = torch.nn.BatchNorm2d(output_channel)
        self.conv3 = torch.nn.Conv2d(output_channel, output_channel, 3, padding=1)
        self.bn3 = torch.nn.BatchNorm2d(output_channel)
        self.relu = torch.nn.ReLU()
        
#         self.up1=torch.nn.ConvTranspose2d(12,25,3,stride=2,padding=1)

    def forward(self, prev_feature_map, x,k):
#         print('before up',x.size())
        if k!=0:
            x = self.up_sampling(x)
        x = torch.cat((x, prev_feature_map), dim=1)
        x = self.relu(self.bn1(self.conv1(x)))
        x = self.relu(self.bn2(self.conv2(x)))
        x = self.relu(self.bn3(self.conv3(x)))
        return x


class UNet(torch.nn.Module):
    def __init__(self):
        super(UNet, self).__init__()

        self.down_block1 = UNet_down_block(3, 16, False)
        self.down_block2 = UNet_down_block(16, 32, True)
        self.down_block3 = UNet_down_block(32, 64, True)
        
        self.mid_conv1 = torch.nn.Conv2d(64, 64, 3, padding=1)
        self.bn1 = torch.nn.BatchNorm2d(64)
        self.mid_conv2 = torch.nn.Conv2d(64, 64, 3, padding=1)
        self.bn2 = torch.nn.BatchNorm2d(64)
        self.mid_conv3 = torch.nn.Conv2d(64, 64, 3, padding=1)
        self.bn3 = torch.nn.BatchNorm2d(64)

        self.up_block5 = UNet_up_block(32, 64, 32)
        self.up_block6 = UNet_up_block(16, 32, 16)
        self.up_block7 = UNet_up_block(3, 16, 16)

        self.last_conv1 = torch.nn.Conv2d(16, 3, 3, padding=1)
        self.last_bn = torch.nn.BatchNorm2d(3)
        self.last_conv2 = torch.nn.Conv2d(3, 1, 1, padding=0)
        self.relu = torch.nn.ReLU()
        
        self.max_pool = torch.nn.MaxPool2d(2, 2)

    def forward(self, x):
#         ins=x.clone()
        self.x1 = self.down_block1(x)
#         print('self.x1',self.x1.size())
        self.x2 = self.down_block2(self.x1)
#         print('self.x2',self.x2.size())
        self.x3 = self.down_block3(self.x2)
#         print('self.x3',self.x3.size())
         
#         self.mid=self.max_pool(self.x3)    
 

        self.x7 = self.relu(self.bn1(self.mid_conv1(self.x3)))
        self.x7 = self.relu(self.bn2(self.mid_conv2(self.x7)))
        self.x7 = self.relu(self.bn3(self.mid_conv3(self.x7)))
        
#         print('prev,x',self.x7.size(),self.x3.size())
        
        x = self.up_block5(self.x3, self.x7, k=0)
        x = self.up_block6(self.x2, x, k=1)
        x = self.up_block7(self.x1, x, k=1)
        x = self.relu(self.last_bn(self.last_conv1(x)))
        x = self.last_conv2(x)
        return x

from torch.autograd import Variable

if __name__ == '__main__':
    net = UNet()  # .cuda()
    print(net)

    # quick shape check with a dummy input
    test_x = Variable(torch.FloatTensor(1, 3, 100, 100))
    out_x = net(test_x)

#     print(out_x.size())

Your learning rate of 10 is really high. Could you try to lower it to something like 1e-3 and try it again?
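
For example (just to illustrate the suggestion, reusing the optimizer setup from the posted code):

# Illustrative only: same Adam optimizer, much smaller learning rate
optimizer = optim.Adam(net.parameters(), lr=1e-3)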

Yes, I have tried that, going from 10 to 1, 0.1, 0.001, and 0.0001, and I get the same response. The 10 was just me testing it out.

Using some random input and target, your model seems to be able to learn, as the loss decreases:

import torch
import torch.nn as nn

x = torch.randn(1, 3, 100, 100)
target = torch.randint(0, 2, (1, 1, 100, 100), dtype=torch.float32)
optimizer = torch.optim.Adam(net.parameters(), lr=1e-3)
criterion = nn.BCEWithLogitsLoss()

for epoch in range(20):
    optimizer.zero_grad()
    output = net(x)
    bce_loss = criterion(output, target)
    dice_loss = dice(output, target)
    loss = bce_loss + dice_loss 
    loss.backward()
    optimizer.step()
    print('Epoch {}, loss {}, bce {}, dice {}'.format(
        epoch, loss.item(), bce_loss.item(), dice_loss.item()))

Hmm… not for me:

That’s quite strange, as I’m using just the code you’ve posted here.
These are my results:

Epoch 0, loss 1.2339667081832886, bce 0.7199727296829224, dice 0.5139939785003662
Epoch 1, loss 1.220333218574524, bce 0.7162541151046753, dice 0.5040791034698486
Epoch 2, loss 1.2090554237365723, bce 0.7133639454841614, dice 0.49569153785705566
Epoch 3, loss 1.197749137878418, bce 0.7105355262756348, dice 0.487213671207428
Epoch 4, loss 1.1889395713806152, bce 0.7084173560142517, dice 0.4805222153663635
Epoch 5, loss 1.1799421310424805, bce 0.7061132192611694, dice 0.4738289713859558
Epoch 6, loss 1.1723219156265259, bce 0.7041332125663757, dice 0.46818870306015015
Epoch 7, loss 1.1651051044464111, bce 0.7022958993911743, dice 0.4628092050552368
Epoch 8, loss 1.1579382419586182, bce 0.7005487084388733, dice 0.4573895335197449
Epoch 9, loss 1.1515626907348633, bce 0.6988978385925293, dice 0.45266491174697876
Epoch 10, loss 1.145705223083496, bce 0.697318434715271, dice 0.4483868479728699
Epoch 11, loss 1.1399022340774536, bce 0.6957539319992065, dice 0.44414830207824707
Epoch 12, loss 1.1344850063323975, bce 0.6943031549453735, dice 0.4401818513870239
Epoch 13, loss 1.128913402557373, bce 0.6928091645240784, dice 0.4361041784286499
Epoch 14, loss 1.1239606142044067, bce 0.6913862824440002, dice 0.4325743317604065
Epoch 15, loss 1.1189703941345215, bce 0.6899511814117432, dice 0.42901915311813354
Epoch 16, loss 1.1143808364868164, bce 0.6886655688285828, dice 0.4257153272628784
Epoch 17, loss 1.110227108001709, bce 0.6875361800193787, dice 0.42269086837768555
Epoch 18, loss 1.1056132316589355, bce 0.6861843466758728, dice 0.4194289445877075
Epoch 19, loss 1.1010327339172363, bce 0.6848099231719971, dice 0.41622287034988403

Are you sure you’re using exactly this code you’ve shared?

Yes, I just restarted my kernel and ran only the cells I pasted here… I am confused…

Just to make sure, you are running this code and get a constant loss?
Which PyTorch version are you using? If you are using an older version (< 0.4.0), could you wrap your tensors into Variables and run it again?
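
For reference, a minimal sketch of what that wrapping looks like on PyTorch < 0.4 (assuming the x, target, net, criterion, dice, and optimizer defined above; not part of the original reply):

# PyTorch < 0.4 sketch: autograd only tracks Variables in those versions,
# so plain tensors have to be wrapped before the forward pass.
from torch.autograd import Variable

x_var = Variable(x)
target_var = Variable(target)

optimizer.zero_grad()
output = net(x_var)
loss = criterion(output, target_var) + dice(output, target_var)
loss.backward()
optimizer.step()
print(loss.data[0])  # .item() does not exist before 0.4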

This seems to be working! I will compare to see what the mistake was; probably something was getting redefined.