Model does not train

Hi all.
I am trying to train a model for segmentation, but the Dice loss never drops below 1, which suggests the model is not training at all. Can you tell me how to make sure that I am at least able to overfit on the training data, just to get started? I have already tried various optimizers and learning rates.

import sys

import torch
import torch.nn as nn
import torch.nn.functional as F


class UNet(nn.Module):
    def __init__(self, n_channels, n_classes):
        super(UNet, self).__init__()
        self.ins = in_conv(n_channels, 16)  #128x128x128x1  ==> 128x128x128x16
        self.en_1 = Encode3d(16, 32)        #128x128x128x16 ==> 64x64x64x32
        self.en_2 = Encode3d(32, 64)        #64x64x64x32    ==> 32x32x32x64
        self.en_3 = Encode3d(64, 128)       #32x32x32x64    ==> 16x16x16x128
        self.br_4 = Bridge3d(128, 256)      #16x16x16x128   ==> 8x8x8x256
        
        self.de_3 = Decode3d(256, 128)      #8x8x8x256    + skip 16x16x16x128  ==> 16x16x16x128
        self.de_2 = Decode3d(128, 64)       #16x16x16x128 + skip 32x32x32x64   ==> 32x32x32x64
        self.de_1 = Decode3d(64, 32)        #32x32x32x64  + skip 64x64x64x32   ==> 64x64x64x32
        self.de_0 = Decode3d(32, 1)         #64x64x64x32  + skip 128x128x128x16 ==> 128x128x128x1
#        self.out = out_conv(24, 1)
    
    def forward(self, x):
        x1 = self.ins(x)
        x2 = self.en_1(x1)
        x3 = self.en_2(x2)
        x4 = self.en_3(x3)
        x5 = self.br_4(x4)
        
        x = self.de_3(x5, x4)
        x = self.de_2(x, x3)
        x = self.de_1(x, x2)
        x = self.de_0(x, x1)
        return x

class in_conv(nn.Module):
    def __init__(self, input_channels, output_channels, kernel_size = 3, stride =1, 
               padding = 1, bias = True, activation = 'relu', res = False):
        super(in_conv, self).__init__()
        self.input_channels = input_channels
        self.output_channels= output_channels
        self.residual = res
        self.bn_1 = nn.InstanceNorm3d(output_channels, affine = True)
        self.bn_2 = nn.InstanceNorm3d(output_channels, affine = True)
        self.conv1 = nn.Conv3d(input_channels, output_channels, kernel_size = 3,
                               stride = 1, padding = 1, bias = True)
        self.conv2 = nn.Conv3d(output_channels, output_channels, kernel_size = 3,
                               stride = 1, padding = 1, bias = True)
        
    def forward(self, x):
#        print("Input conv:", x.shape)
        if self.residual == True:
            skip = x
        x = self.conv1(x)
        x = self.bn_1(x)
        x = F.leaky_relu(x, negative_slope = 0.01, inplace = True)
        x = self.conv2(x)
        if self.residual == True:
            x = x + skip
#        print("Exiting Input Conv:", x.shape)
        return x

############# ENCODING MODULE #######################################
class Encode3d(nn.Module):
    def __init__(self, input_channels, output_channels, res = False):
        super(Encode3d, self).__init__()
        self.input_channels = input_channels
        self.output_channels= output_channels
        self.residual = res
        self.bn_1 = nn.InstanceNorm3d(input_channels, affine = True)
        self.bn_2 = nn.InstanceNorm3d(output_channels, affine = True)
        self.conv1 = nn.Conv3d(input_channels, output_channels, kernel_size = 3,
                               stride = 1, padding = 1, bias = True)
        self.conv2 = nn.Conv3d(output_channels, output_channels, kernel_size = 3,
                               stride = 1, padding = 1, bias = True)
        self.maxpool = nn.MaxPool3d(2)
        
    def forward(self, x):
#        print("Encoding:", x.shape)
        x = self.maxpool(x)
        if self.residual == True:
            skip = x
        x = self.bn_1(x)
        x = F.leaky_relu(x)
        x = self.conv1(x)
        x = self.bn_2(x)
        x = F.leaky_relu(x)
        x = self.conv2(x)
        if self.residual == True:
            x = x + skip
#        print("Exiting Encoding:", x.shape)
        return x
    
class Bridge3d(nn.Module):
    def __init__(self, input_channels, output_channels, res = False):
        super(Bridge3d, self).__init__()
        self.input_channels = input_channels
        self.output_channels= output_channels
        self.residual = res
        self.bn_1 = nn.InstanceNorm3d(input_channels, affine = True)
        self.bn_2 = nn.InstanceNorm3d(output_channels, affine = True)
        self.conv1 = nn.Conv3d(input_channels, output_channels, kernel_size = 3,
                               stride = 1, padding = 1, bias = True)
        self.conv2 = nn.Conv3d(output_channels, output_channels, kernel_size = 3,
                               stride = 1, padding = 1, bias = True)
        self.maxpool = nn.MaxPool3d(2)
        
    def forward(self, x):
#        print("Bridging", x.shape)
        x = self.maxpool(x)
        if self.residual == True:
            skip = x
        x = self.bn_1(x)
        x = F.leaky_relu(x)
        x = self.conv1(x)
        x = self.bn_2(x)
        x = F.leaky_relu(x)
        x = self.conv2(x)
        if self.residual == True:
            x = x + skip
#        print("Exiting Bridging:", x.shape)
        return x

class Decode3d(nn.Module):
    def __init__(self, input_channels, output_channels, conv_bias=True,
                 lrelu_inplace=True, Trilinear = True, res = False):
        super(Decode3d, self).__init__()
        if Trilinear:
            self.up = nn.Upsample(scale_factor = 2, mode = 'trilinear', align_corners = True)
        else:
            self.up = nn.ConvTranspose3d(input_channels, input_channels,
                        kernel_size = 3, stride = 2, padding = 1, output_padding = 1)
        self.lrelu_inplace = lrelu_inplace
        self.conv_bias = conv_bias
        self.input_channels = input_channels
        self.output_channels= output_channels
        self.residual = res
        self.conv1 = nn.Conv3d(int(input_channels*1.5), output_channels, kernel_size = 3,
                               stride = 1,padding = 1, bias=conv_bias)
        self.bn_1 = nn.InstanceNorm3d(int(input_channels*1.5), affine=True)
        self.conv2 = nn.Conv3d(output_channels, output_channels, kernel_size = 1,
                               stride = 1,padding = 0, bias=conv_bias)
        self.bn_2 = nn.InstanceNorm3d(output_channels, affine=True)

    def forward(self, x1, x2):
#        print("Decoding x1: ", x1.shape,"x2: ", x2.shape)
        x = self.up(x1)
#        print("After upconv, X shape: ", x.shape)
        x = torch.cat([x, x2], dim = 1)
#        print("After concatenating X shape: ", x.shape)
        if self.residual == True:
            skip = x
        x = self.bn_1(x)
#        print("Successful_batch_normalization")
        x = F.leaky_relu(x)
        x = self.conv1(x)
        x = self.bn_2(x)
        x = F.leaky_relu(x)
        x = self.conv2(x)
        if self.residual == True:
            x = x + skip
#        print("Exiting decoding", x.shape)
        return x

class out_conv(nn.Module):
    def __init__(self, input_channels, output_channels = 1, kernel_size = 3,
          stride = 1, padding = 1, bias = True, activation = 'relu', Trilinear = True,
          lrelu_inplace = True, conv_bias = True, res = True):
        super(out_conv, self).__init__()
        if Trilinear:
            self.up = nn.Upsample(scale_factor = 2, mode = 'trilinear', align_corners = True)
        else:
            self.up = nn.ConvTranspose3d(input_channels, input_channels, kernel_size = 3,
                                         stride = 2, padding = 1, output_padding = 1)
        self.lrelu_inplace = lrelu_inplace
        self.conv_bias = conv_bias
        self.input_channels = input_channels
        self.output_channels= output_channels
        self.residual = res  
        self.bn_1 = nn.InstanceNorm3d(int(input_channels*1.5), affine = True)
        self.conv1 = nn.Conv3d(int(input_channels*1.5), output_channels, kernel_size = 3,
                                  stride = 1, padding = 1, bias = True)
        self.bn_2 = nn.InstanceNorm3d(input_channels, affine = True)
        self.bn_3 = nn.InstanceNorm3d(output_channels, affine = True)
        self.conv2 = nn.Conv3d(input_channels, output_channels, kernel_size = 3,
                                  stride = 1, padding = 1, bias = True)

    def forward(self, x1, x2):
#        print("Out conving x1:", x1.shape, "x2: ", x2.shape)
        x = self.up(x1)
#        print("After upconving, X shape:", x.shape)
        x = torch.cat([x, x2], dim = 1)
#        print("After concatenating X shape: ", x.shape)
        if self.residual == True:
            skip = x
        x = self.bn_1(x)
#        print("Successful_batch_normalization")
        x = F.leaky_relu(x)
        x = self.conv1(x)
        x = self.bn_2(x)
        x = F.leaky_relu(x)
        x = self.conv2(x)
        if self.residual == True:
            x = x + skip
        x = self.bn_3(x)
        x = torch.sigmoid(x)
        
        return x

def train(epoch, train_loss_list):
        model.train()
        for batch_idx, (image_name, image, mask, affine) in enumerate(train_loader):
            if params['cuda']:
                image, mask = image.cuda(), mask.cuda()          #Loading images into the GPU and ignoring the affine.
    
    
            optimizer.zero_grad()
    
            output = model(image)
    
            loss = criterion(output, mask)
            train_loss_list.append(loss.item())
    
            loss.backward()
            optimizer.step()
    
            if batch_idx % int(params['log_interval']) == 0:
                print('Train Epoch: {} [{}/{} ({:.0f}%)]\tAverage DICE Loss: {:.6f}'.format(
                    epoch, batch_idx * len(image), len(train_loader.dataset),
                    100. * batch_idx / len(train_loader), loss.item()))
        for param_group in optimizer.param_groups:
            print("Learning rate: ", param_group['lr'])
        torch.save({
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'loss_state_dict' : criterion.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'loss': loss.item()
            }, './tmp/model.pkl')
        sys.stdout.flush()
criterion = DiceLoss()
for i in range(int(params['max_epochs'])):
        if i % int(params['val_epoch_interval']):
            create_epoch_dir(i, params['epoch'])
        train(i, train_loss_list)
class DiceLoss(nn.Module):

    def __init__(self):
        super(DiceLoss, self).__init__()

    def forward(self, output, mask):
        # output is assumed to already contain probabilities in [0, 1];
        # both tensors are expected to have shape (N, 1, D, H, W).
        probs = torch.squeeze(output, 1)
        mask = torch.squeeze(mask, 1)

        # Sum over every spatial dimension so each sample gets a single Dice score.
        intersection = torch.sum(probs * mask, dim=(1, 2, 3))
        den1 = torch.sum(probs * probs, dim=(1, 2, 3))
        den2 = torch.sum(mask * mask, dim=(1, 2, 3))

        eps = 1e-8
        dice = 2 * ((intersection + eps) / (den1 + den2 + eps))

        loss = 1 - torch.sum(dice) / dice.size(0)
        return loss

Here is a rough structure of the model. It is a simple UNet.
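
For what it is worth, this is roughly how I sanity-check the forward pass (just a sketch; the batch size and the 128x128x128 single-channel input are assumptions based on the shape comments above):

# Quick forward-pass shape check on a random volume.
model = UNet(n_channels=1, n_classes=1)
x = torch.randn(1, 1, 128, 128, 128)   # (batch, channels, D, H, W)
with torch.no_grad():
    out = model(x)
print(out.shape)                       # expected: torch.Size([1, 1, 128, 128, 128])
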
Can anyone help me with this?
Thanks.

A simple way of testing whether your model can reach a global minimum (i.e. overfit) is to drastically reduce your training data. Train your model on 10 images for a lot of epochs and tell me what happens.
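
For example, something along these lines (just a sketch: train_dataset is a placeholder for your dataset object, and UNet/DiceLoss are the classes from your post):

from torch.utils.data import DataLoader, Subset

# Hypothetical overfitting test: keep only the first 10 training samples and
# train on them for many epochs. The Dice loss should head towards 0; if it
# stays stuck near 1, the model, the loss, or the data pipeline is at fault.
tiny_set = Subset(train_dataset, list(range(10)))   # train_dataset: your existing dataset object
tiny_loader = DataLoader(tiny_set, batch_size=1, shuffle=True)

model = UNet(n_channels=1, n_classes=1).cuda()
criterion = DiceLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

for epoch in range(200):
    for _, image, mask, _ in tiny_loader:           # (name, image, mask, affine), as in your loader
        image, mask = image.cuda(), mask.cuda()
        optimizer.zero_grad()
        # DiceLoss treats its first argument as probabilities, so a sigmoid is
        # applied here; drop it if your network already ends in one.
        loss = criterion(torch.sigmoid(model(image)), mask)
        loss.backward()
        optimizer.step()
    print(epoch, loss.item())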

Please upload your training loop and loss function.

Updated. Please check again.

I am unable to overfit the network on 10 training images.

If your main objective is to debug your code, you may find these tips useful:

  1. Are you sure your labels are correct? (A quick way to check points 1 and 3 is sketched after this list.)
  2. Did you try other learning rates, both smaller and larger?
  3. Did you normalize your training data?
  4. Did you try other losses?
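
For points 1 and 3, a minimal check on one batch could look like this (a sketch only; train_loader is the loader from your training loop, and the zero-mean/unit-variance normalization is just one common choice):

# Pull one batch and inspect labels and intensities.
name, image, mask, affine = next(iter(train_loader))

# Point 1: a binary segmentation mask should contain only 0s and 1s,
# and should not be (almost) empty.
print(mask.unique())                      # expected: tensor([0., 1.])
print(mask.sum().item() / mask.numel())   # foreground fraction

# Point 3: inspect the input intensity range; raw medical volumes often
# need normalization, e.g. zero mean / unit variance per volume.
print(image.min().item(), image.max().item(), image.mean().item(), image.std().item())
image = (image - image.mean()) / (image.std() + 1e-8)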

If your objective is to get something running and then tweak the architecture for your specific results, I think you should try cloning a GitHub repo and reproducing the results the author reports. For instance, I found this repo that might be useful to you. Try running it and understanding what differs from your implementation; for example, it seems the author uses BCELoss instead of a custom loss function.
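
If you want to try swapping the loss quickly, a minimal sketch could look like this (an assumption on my side: your UNet's last layer returns raw values with no sigmoid, so nn.BCEWithLogitsLoss, which applies the sigmoid internally, is the safer choice):

# Hypothetical drop-in replacement for DiceLoss while debugging;
# model, image and mask are the objects from your training loop.
criterion = nn.BCEWithLogitsLoss()

output = model(image)                   # raw logits, shape (N, 1, D, H, W)
loss = criterion(output, mask.float())  # target must be a float tensor of the same shape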