3D Unet Implementation doesn't overfit

Hello!
I’ve been trying to build the original model from the 3D U-Net paper, but when I train the model with only 1 image, it can’t overfit. I’m not sure if I’m missing something or if 3D U-Net is not good enough to overfit. I’ve tried different learning rates with the Adam optimizer, but the maximum IoU I get is 89%.

This is the code I’m using:

import torch.nn.functional as F
import torch.nn as nn
import torch


class Convolution(nn.Module):
    """3x3x3 convolution followed by batch normalisation and a ReLU.

    The convolution is created with ``kernel_size=3`` and no explicit
    padding, so with PyTorch's defaults every spatial dimension of the
    output is 2 voxels smaller than the input.

    Args:
        in_channels: Number of channels of the incoming feature map.
        out_channels: Number of channels produced by the convolution.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.convolution = nn.Conv3d(in_channels, out_channels, kernel_size=3)
        self.batch = nn.BatchNorm3d(out_channels)

    def forward(self, x):
        """Apply convolution -> batch norm -> ReLU to ``x`` and return it."""
        features = self.convolution(x)
        normalised = self.batch(features)
        return F.relu(normalised)


class UpConvolution(nn.Module):
    """Up-sampling step of the U-Net expanding path.

    The input is up-sampled with a transposed convolution
    (``kernel_size=2, stride=2``), the skip connection ("bridge") is
    centre-cropped to the spatial size of the up-sampled tensor, and both
    are concatenated along the channel dimension.

    Args:
        in_channels: Channels of the tensor being up-sampled.
        out_channels: Channels produced by the transposed convolution.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()

        self.up_convolution = nn.ConvTranspose3d(in_channels, out_channels,
                                                 kernel_size=2, stride=2)

    def crop(self, bridge, up):
        """Centre-crop ``bridge`` to the spatial size of ``up``.

        Each of the three trailing dimensions gets its own offset and its
        own target length, so feature maps whose last two dimensions differ
        are cropped correctly as well.

        Args:
            bridge: 5-D tensor from the contracting path.
            up: 5-D tensor whose trailing three sizes define the crop.

        Returns:
            A view of ``bridge`` with the same trailing three sizes as
            ``up``; batch and channel dimensions are left untouched.
        """
        bounds = []
        for full, target in zip(bridge.size()[2:], up.size()[2:]):
            start = (full - target) // 2
            bounds.append((start, start + target))
        (d0, d1), (h0, h1), (w0, w1) = bounds

        return bridge[:, :, d0:d1, h0:h1, w0:w1]

    def forward(self, x, bridge):
        """Up-sample ``x`` and concatenate it with the cropped ``bridge``.

        Returns:
            Tensor whose channel dimension holds the cropped bridge
            channels first, followed by the up-sampled channels.
        """
        up = self.up_convolution(x)
        # Bridge is the opposite block of the up part
        cropped_bridge = self.crop(bridge, up)
        return torch.cat((cropped_bridge, up), 1)


class UNet(nn.Module):
    """3D U-Net with three pooling steps and three up-convolution steps.

    The network takes a single-channel volume and returns a single-channel
    volume of sigmoid probabilities. The spatial size of the output is
    smaller than the input; the thread's own example pairs a
    (1, 1, 116, 132, 132) input with a (1, 1, 28, 44, 44) target.
    """

    def __init__(self):
        super().__init__()

        # Shared 2x2x2 max pooling between the levels of the contracting path.
        self.pooling = nn.MaxPool3d(kernel_size=2, stride=2)

        # Contracting path: two convolutions per level.
        self.conv_1_32 = Convolution(1, 32)
        self.conv_32_64 = Convolution(32, 64)
        self.conv_64_64 = Convolution(64, 64)
        self.conv_64_128 = Convolution(64, 128)
        self.conv_128_128 = Convolution(128, 128)
        self.conv_128_256 = Convolution(128, 256)
        self.conv_256_256 = Convolution(256, 256)
        self.conv_256_512 = Convolution(256, 512)

        # Expanding path: up-convolution + skip concatenation, then two
        # convolutions. The input widths 768/384/192 are the up-sampled
        # channels plus the channels of the concatenated skip connection.
        self.conv_512_512_UpConv = UpConvolution(512, 512)
        self.conv_768_256_Conv = Convolution(768, 256)
        self.conv_256_256_Conv = Convolution(256, 256)
        self.conv_256_256_UpConv = UpConvolution(256, 256)
        self.conv_384_128_Conv = Convolution(384, 128)
        self.conv_128_128_Conv = Convolution(128, 128)
        self.conv_128_128_UpConv = UpConvolution(128, 128)
        self.conv_192_64_Conv = Convolution(192, 64)
        self.conv_64_64_Conv = Convolution(64, 64)

        # Final 1x1x1 convolution down to a single output channel.
        self.conv_64_1 = nn.Conv3d(64, 1, 1)

    def forward(self, x):
        """Run the full encoder/decoder and return sigmoid probabilities."""
        # Contracting path; enc1..enc3 are kept as skip connections.
        enc1 = self.conv_32_64(self.conv_1_32(x))
        enc2 = self.conv_64_128(self.conv_64_64(self.pooling(enc1)))
        enc3 = self.conv_128_256(self.conv_128_128(self.pooling(enc2)))
        bottom = self.conv_256_512(self.conv_256_256(self.pooling(enc3)))

        # Expanding path, deepest level first.
        dec3 = self.conv_512_512_UpConv(bottom, enc3)
        dec3 = self.conv_256_256_Conv(self.conv_768_256_Conv(dec3))

        dec2 = self.conv_256_256_UpConv(dec3, enc2)
        dec2 = self.conv_128_128_Conv(self.conv_384_128_Conv(dec2))

        dec1 = self.conv_128_128_UpConv(dec2, enc1)
        dec1 = self.conv_64_64_Conv(self.conv_192_64_Conv(dec1))

        logits = self.conv_64_1(dec1)
        return torch.sigmoid(logits)

I assume you are using nn.BCELoss as your criterion.
If so, could you remove the last sigmoid and use nn.BCEWithLogitsLoss?

Let me know, if this helps in any sense or if you are still seeing this behavior.

I’m plotting the predicted image and I see that BCE performs better than BCEwithlogitsloss, so I decided to keep using BCE. The problem is that after 200 epochs, it can’t increase the accuracy. It should be able to overfit really quickly if I’m using only one training image.
I debugged the neural network and I’m getting the same output shape as in the paper. I’ve seen many 3D Unet and 2D Unet implementations, but they change the output size to be the same as the input’s, and that’s not what they do in the papers.

nn.BCELoss + sigmoid should perform equally well as nn.BCEWithLogitsLoss + logits in the unsaturated area.

The latter approach gives you more numerical stability.
If the former one is performing better, this could mean e.g. that your learning rate is too high and thus the saturated gradients are beneficial for the model.

1 Like

This is the training code I’m using. I had to create my own Resize function since I’m dealing with TIFF images. I use F.interpolate for that.
I’ve been testing the network with only 1 image and I can’t overfit. I also implemented different models from the web and all of them behave in the same way, which is why I think the problem is in my training, data loader or transformation code.
What do you think could be the problem?

import torch
import unet3D
import numpy as np
import transformations_3_D
import torchvision.transforms as transforms
import torch.nn as nn
import torch.optim as optim
import visdom
import time
from sklearn.metrics import jaccard_score
from torch.functional import F
from dataset_3D import MacularHole
from torch.utils import data
# from images3D_visualizer import ImageVisualizer
from matplotlib import pyplot as plt
import unet3d_github

# Resize helpers for label and prediction volumes
def labels_resize(tensor, size, normalize=False):
    """Resize a label tensor with nearest-neighbour interpolation.

    Args:
        tensor: Tensor accepted by ``F.interpolate``. The callers in this
            script pass 3-element sizes, which presumably means a 5-D
            (N, C, D, H, W) tensor — TODO confirm against the dataset.
        size: Target spatial size forwarded to ``F.interpolate``.
        normalize: When true, min-max rescale the resized tensor so its
            smallest value becomes 0 and its largest becomes 1.

    Returns:
        The resized tensor, rescaled if ``normalize`` is set. A constant
        tensor (minimum equals maximum, e.g. all background or all
        foreground) is returned without rescaling.
    """
    tensor = F.interpolate(tensor, size=size, mode="nearest")
    if not normalize:
        return tensor

    low = torch.min(tensor)
    high = torch.max(tensor)
    # A constant tensor has zero range; dividing by it would fill the result
    # with NaN (0 / 0), which then poisons the loss. Leave it untouched.
    if low == high:
        return tensor

    return (tensor - low) / (high - low)


def prediction_resize(tensor, size):
    """Resample a prediction volume to ``size`` with trilinear interpolation.

    ``align_corners=True`` is used, so the corner voxels of the input and
    the output coincide.

    Args:
        tensor: Tensor accepted by ``F.interpolate`` in ``"trilinear"`` mode.
        size: Target spatial size forwarded to ``F.interpolate``.

    Returns:
        The interpolated tensor.
    """
    return F.interpolate(tensor, size=size, mode="trilinear", align_corners=True)


def main():
    """Train the 3D U-Net on the ``MacularHole`` training split.

    Each epoch runs one pass over the training loader with ``nn.BCELoss``
    and Adam, then prints the mean loss, the mean Jaccard score and the
    largest prediction value of the last batch.

    Labels are first resized to (49, 200, 200); a normalised copy resized
    to (28, 44, 44) is the loss target. For the Jaccard score the
    prediction is resized to (49, 200, 200) and thresholded at 0.5.
    """
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    transform = transforms.Compose([transformations_3_D.Resize((116, 132, 132))])

    seed = np.random.randint(0, 1000)

    train_set = MacularHole("dataset_first_tests_45_imgs/", mode='train', seed=seed, transform=transform,
                            training_size=1, valid_size=9, testing_size=9)

    train_loader = torch.utils.data.DataLoader(train_set, batch_size=1, shuffle=True, num_workers=16)

    # The model must exist before it is moved to the device; with the
    # instantiation commented out, the next line raised UnboundLocalError.
    # NOTE(review): assumes unet3D.UNet is the network under test — swap in
    # unet3d_github's model here if that is the one being evaluated.
    model = unet3D.UNet()
    model = model.to(device)

    criterion = nn.BCELoss()
    optimizer = optim.Adam(model.parameters(), lr=0.0001)

    # Per-epoch history (kept for plotting, e.g. with visdom)
    train_performance_values = []
    train_loss_values = []

    epochs = 5000
    train_max_value = 0.0
    for i in range(epochs):
        train_loss = 0.0
        jaccard_scores_train = []

        model.train()
        for images, labels in train_loader:
            images, labels = images.float(), labels.float()
            images, labels = images.to(device), labels.to(device)
            labels = labels_resize(labels, (49, 200, 200))

            optimizer.zero_grad()
            train_predictions = model(images)
            # The loss target must match the network's output size.
            resize_labels = labels_resize(labels, (28, 44, 44), normalize=True)
            # Store a plain float so the log line prints a number rather
            # than a tensor repr.
            train_max_value = torch.max(train_predictions).item()
            loss = criterion(train_predictions, resize_labels)
            loss.backward()
            optimizer.step()
            train_loss += loss.item()

            # jaccard scores
            with torch.no_grad():
                train_predictions = prediction_resize(train_predictions, (49, 200, 200))
                # Binarizing labels
                binary_labels = (labels != 0).cpu().numpy().astype(int)
                # Thresholding predictions
                binary_train_predictions = (train_predictions > 0.5).cpu().numpy().astype(int)
                # Flattening numpy arrays for jaccard_score and adding each value to a list
                jaccard_scores_train.append(
                    jaccard_score(binary_labels.flatten(), binary_train_predictions.flatten()))

        # calculating loss and performance in each epoch
        train_loss = train_loss / len(train_set)
        train_loss_values.append(train_loss)
        train_performance = np.mean(jaccard_scores_train)
        train_performance_values.append(train_performance)

        print(f'Epoch: {i + 1} |' f'\tTrain loss: {train_loss} |' f'\t| Train performance: {train_performance} |'
              f'\t| Max training: {train_max_value} |')



if __name__ == "__main__":
    # Script entry point. Visdom plotting (vis = visdom.Visdom()) is
    # currently disabled.
    main()

I’m able to overfit the model on random data using:

# Sanity check: fit the network to a single random volume and a random
# binary target of the network's output size.
device = 'cuda'
model = UNet().to(device)
volume = torch.randn(1, 1, 116, 132, 132).to(device)
target = torch.randint(0, 2, (1, 1, 28, 44, 44), device=device).float()

optimizer = optim.Adam(model.parameters(), lr=1e-3)
loss_fn = nn.BCELoss()

n_iterations = 500
for step in range(n_iterations):
    optimizer.zero_grad()
    loss = loss_fn(model(volume), target)
    loss.backward()
    optimizer.step()
    print('Iter {}, loss {:.4f}'.format(step, loss.item()))

In your code it seems you are somehow normalizing the labels.
Does it mean you don’t use a target tensor containing only zeros and ones?
If so, note that binary cross entropy will only yield a zero loss, if the target and predictions are either both (close to) zero or one as described here.