U-Net can't perfectly segment image

hey, i’m using a unet with a pretrained resnet34 encoder to segment diabetic retinopathy images, here is my full architecture

import torch
import torch.nn as nn
import torchvision

# pretrained ResNet34 backbone whose stages are reused by the encoder below
resnet = torchvision.models.resnet34(pretrained=True)

class Encoder(nn.Module):
    def __init__(self):
        super().__init__()
        self.layer1 = nn.Sequential(
            resnet.conv1,
            resnet.bn1,
            nn.LeakyReLU(negative_slope=0.01),
            resnet.layer1
        )
        self.layer2 = resnet.layer2
        self.layer3 = resnet.layer3
        self.layer4 = resnet.layer4
        
    def forward(self, x):
        x1 = self.layer1(x)
        x2 = self.layer2(x1)
        x3 = self.layer3(x2)
        x4 = self.layer4(x3)
        return x1, x2, x3, x4

and my middle conv is

class MiddleConv(nn.Module):
    def __init__(self, in_channels, out_channels):
        super().__init__()

        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1, stride=1)
        self.relu1 = nn.LeakyReLU(negative_slope=0.01)
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1, stride=1)
        self.relu2 = nn.LeakyReLU(negative_slope=0.01)
        self.bn1 = nn.BatchNorm2d(out_channels)

    def forward(self, inputs):
        x = self.conv1(inputs)
        x = self.relu1(x)
        x = self.conv2(x)
        x = self.relu2(x)
        x = self.bn1(x)

        return x

and my decoder and merge layer

class Dec_Block(nn.Module):
    def __init__(self, in_channels, out_channels):
        super().__init__()

        # Up-convolution (sub-pixel shuffle; PixelShuffle_ICNR as in fastai.layers)
        self.upconv = PixelShuffle_ICNR(in_channels, out_channels, scale=2)
        
        # LeakyReLU
        self.relu = nn.LeakyReLU(negative_slope=0.01)
        
        # Batch normalization
        self.bn = nn.BatchNorm2d(out_channels, track_running_stats=False)
        
        # Basic residual block on the concatenated features
        self.conv = BasicBlock(in_channels + out_channels, out_channels)
        
    
    def forward(self, inputs, skip):
        up_x = self.upconv(inputs)
        up_x = self.relu(up_x)
        up_x = self.bn(up_x)  
        skip = nn.functional.interpolate(skip, size=up_x.size()[2:], mode='bilinear', align_corners=True)
        x = torch.cat([up_x, skip], dim=1)
        x = self.conv(x)
  
        return x

class MergeLayer(nn.Module):
    def __init__(self):
        super(MergeLayer, self).__init__()

    def forward(self, x, skip):
        x = torch.cat([x, skip], dim=1)
        return x

class Build_Unet(nn.Module):
    def __init__(self):
        super().__init__()

        """ Encoder """
        self.encoder = Encoder()

        """ Middle Convolution """
        self.mc1 = MiddleConv(512, 1024)
        self.mc2 = MiddleConv(1024, 512)

        """ Decoder """
        self.decoder1 = Dec_Block(512, 256)
        self.decoder2 = Dec_Block(256, 128)
        self.decoder3 = Dec_Block(128, 64)
        self.decoder4 = Dec_Block(64, 32)

        """ Merge Layer """
        self.merge = MergeLayer()

        """ Segmentation Convolution """
        self.segmentation = nn.Sequential(nn.Conv2d(35, 1, kernel_size=1))

    def forward(self, x):
        """ Encoder """
        x1, x2, x3, x4 = self.encoder(x)
        mc1 = self.mc1(x4)
        mc2 = self.mc2(mc1)

        """ Decoder """
        d1 = self.decoder1(mc2, x4)
        d2 = self.decoder2(d1, x3)
        d3 = self.decoder3(d2, x2)
        d4 = self.decoder4(d3, x1)

        merged = self.merge(d4, x)

        out = self.segmentation(merged)

        return out

i train my model using batch size = 2, learning rate = 0.001, and 30 epochs, with BCE loss and the Adam optimizer, but my dice coefficient is low and my segmentation result looks like this:
[segmentation result image]
does anyone know how i can improve my model’s performance? thanks!

Hi Anastasia!

I can’t speak to the correctness of your code nor to the appropriateness
of your model architecture, but I would make the following comments:

You don’t say how big your training set is, but, in general, thirty epochs of
training is not a lot.

I would suggest that you first attempt to overfit your model on a small
training set, say something like sixteen images. You should be able to
train your model – perhaps by running many epochs – so that it makes
very good predictions for the images in your (small) training set.

This won’t be a good model, because – due to overfitting – it won’t make
good predictions for images that are not in your training set.

If you can’t get such an overfitting test to work, you either have a bug
somewhere or your architecture isn’t appropriate to your task.
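
Here is a minimal sketch of such an overfitting check, assuming your
existing model, criterion, optimizer, and device, plus a train_dataset
object (that name is hypothetical):

from torch.utils.data import DataLoader, Subset

# take a tiny fixed subset of the training data (16 images)
tiny_set = Subset(train_dataset, list(range(16)))
tiny_loader = DataLoader(tiny_set, batch_size=2, shuffle=True)

model.train()
for epoch in range(500):  # deliberately many epochs
    for data, target in tiny_loader:
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()
        loss = criterion(model(data), target)
        loss.backward()
        optimizer.step()

# the model should now predict these 16 images almost perfectly;
# if it can't, suspect a bug or an unsuitable architecture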

If you pass this test, then try training on an appropriately large training
set. (As a rule, more data is better.) Train, perhaps for many epochs,
while tracking your loss and performance metrics (e.g., the Dice score)
on both your training set and validation set. Can you train enough that
you get a low loss and acceptable Dice score on both your training and
validation sets? If your validation-set results start getting worse (even if
your training-set results are still getting better), your training has started
to overfit, and you will have to address that issue if your validation-set
results are not yet good enough.

Good Luck!

K. Frank

my training set contains 53 images and my test set contains 27 images, is that too small for training?

That’s very small, as far as datasets are concerned.

Here are some things you could try:

  1. Apply some sensible image transforms and augmentations:

https://pytorch.org/vision/master/transforms.html

  2. Use BCEWithLogitsLoss on the raw model outputs, instead of BCE, for better numerical stability (see the sketch after this list).

  3. When you or I look at these eyeball scans, we map them, without a second thought, onto a 3d object, which makes it easier to identify defects in the context of the surface of a spheroid. In contrast, 2d convolutions are limited to detecting textures and 2d edges, and those tend to have issues when something is rotated along the depth dimension, as is the case in your visual example. Combining the two kinds of convolution in the model may yield better model comprehension. So, consider applying a 3d convolution branch, as detailed here: GitHub - CerebralSeed/Hybrid-3D-UNet: Model for Hybrid 3D UNet .
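
Regarding point 2, here is a minimal sketch of that swap, assuming your model returns raw logits (no final sigmoid) and that images and masks stand for one batch from your loader:

import torch.nn as nn

# BCEWithLogitsLoss fuses the sigmoid with the BCE computation,
# which is more numerically stable than sigmoid followed by nn.BCELoss
criterion = nn.BCEWithLogitsLoss()

logits = model(images)                   # raw outputs, no sigmoid applied
loss = criterion(logits, masks.float())  # masks are the 0/1 ground truths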

ah thanks! if i use image transforms in pytorch, will the transformed image still match the ground truth automatically?

No. You’ll have to apply an identical transform.

Any transforms that change the location of the defects should be applied to both the images and the ground truth. Transforms that only change color/blur/etc., without shifting anything, don’t need to be applied to the ground truth. You could concat the ground truth on the channel dim before applying the shifting transforms, then be sure to separate it again afterwards.

like this?

class ProcessDataset(Dataset):
    def __init__(self, x, y):
        self.x = x
        self.y = y
        self.transform = transforms.Compose([
                            transforms.Resize(256),
                            transforms.RandomVerticalFlip(),
                            transforms.RandomRotation(degrees=10),
                            transforms.ColorJitter(brightness=0.2, contrast=0.2),
                            transforms.ToTensor()])
        self.transform2 = transforms.Compose([
                            transforms.Resize(256),
                            transforms.RandomVerticalFlip(),
                            transforms.RandomRotation(degrees=10),
                            transforms.ToTensor()])

    def __len__(self):
        return len(self.x)

    def __getitem__(self, idx):
        img_x = Image.open(self.x[idx])
        img_y = Image.open(self.y[idx]).convert("L")

        img_x = self.transform(img_x)
        img_y = self.transform2(img_y)  

        return img_x, img_y

Not quite. The random transforms are sampled independently each time they are called, so you want to ensure that each image/ground-truth pair is processed together. Otherwise, one might be shifted in one direction while the other is shifted in the opposite direction, which won’t help the model get better at finding the ground truth.

Of the above transforms, you likely would not want ColorJitter applied to the ground truths. So you might have self.transforms_data and self.transforms_all delineated, and then applied accordingly.

In the following example, I am going to assume your ground truths only have 1 channel, based on your earlier model. If there is more than one, you can update accordingly. And I will assume the channels are on dim=0 after transforms.ToTensor() is applied.

import torch
from PIL import Image
from torch.utils.data import Dataset
from torchvision import transforms

class ProcessDataset(Dataset):
    def __init__(self, x, y):
        self.x = x
        self.y = y

        self.pre_process = transforms.ToTensor()

        self.transform_data = transforms.Compose([
                            transforms.ColorJitter(brightness=0.2, contrast=0.2)])

        self.transform_all = transforms.Compose([
                            transforms.Resize(256),
                            transforms.RandomVerticalFlip(),
                            transforms.RandomRotation(degrees=10)])

    def __len__(self):
        return len(self.x)

    def __getitem__(self, idx):
        img_x = Image.open(self.x[idx])
        img_y = Image.open(self.y[idx]).convert("L")

        # ToTensor converts to a tensor, scales values to [0, 1], and puts channels on dim 0
        img_x = self.pre_process(img_x)
        img_y = self.pre_process(img_y)

        # Concatenate on the channel dim and apply the resize/shifting transforms once,
        # so the image and its ground truth receive the identical random transform
        img_all = torch.cat([img_x, img_y])
        img_all = self.transform_all(img_all)

        # Split again and apply any color/saturation/hue transforms to the data only
        img_x, img_y = img_all[:-1, ...], img_all[-1:, ...]
        img_x = self.transform_data(img_x)

        return img_x, img_y
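
One caveat with the sketch above: transforms.Resize interpolates bilinearly by default, so the transformed mask may end up with values strictly between 0 and 1. If later code counts mask == 0 / mask == 1 pixels, it is safest to re-binarize the mask at the end of __getitem__, for example:

# threshold the transformed mask back to hard 0/1 values
img_y = (img_y > 0.5).float()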

ah thanks so much! since my dataset only contains a train and a test set, should the transforms/augmentations be applied only to the train set?

Correct. We just want to give the model enough variety in the train set so that it does well on the test set, or on any new examples it sees in deployment.

okay, i’ve tried that augmentation and increased the epochs to 50, but why is my dice coefficient still low?
here’s my training code

for epoch in tqdm(range(EPOCHS)):
    model.train()

    totalTrainLoss = 0
    totalTestLoss = 0

    for i, (data, target) in enumerate(train_loader):
        data, target = data.to(device), target.to(device)

        output = model(data)

        loss = criterion(output, target)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # accumulate a float with .item(), not a graph-holding tensor
        totalTrainLoss += loss.item()

    model.eval()
    with torch.no_grad():
        total_dice = 0

        for data, target in test_loader:
            data, target = data.to(device), target.to(device)

            output = model(data)

            loss = criterion(output, target)
            totalTestLoss += loss.item()

            # the model outputs raw logits (BCEWithLogitsLoss), so apply a
            # sigmoid before thresholding to binary predictions
            pred = (torch.sigmoid(output) > 0.5).float()

            # calculate Dice coefficient using MONAI (mean over batch/channels)
            dice = compute_dice(pred, target).mean().item()
            total_dice += dice

    avg_dice = total_dice / len(test_loader)

    # trainSteps / testSteps: number of train / test batches, defined elsewhere
    avgTrainLoss = totalTrainLoss / trainSteps
    avgTestLoss = totalTestLoss / testSteps

  1. What is your loss function now (i.e., criterion)? Can you show where you define the loss function, and are you making use of pos_weight? (See here: multilabel classification - How to calculate unbalanced weights for BCEWithLogitsLoss in pytorch - Stack Overflow)

  2. Have you tried viewing a few of the transforms to see if they resemble typical deviations in the train and test data? For example, if all the images have the eye facing one direction, then a horizontal flip might not have any benefit. You want to use the augmentations to diversify your training dataset, but not beyond the normal bounds of what a scan might look like. So you may need to fine-tune the transform settings.

  1. here’s my loss function:

criterion = nn.BCEWithLogitsLoss()

  2. no, i haven’t looked at the transforms in detail yet. what transforms should i use for my dataset (fundus images)?

  1. From looking at the ground truths you provided so far, the majority of elements are negative (i.e., 0s instead of 1s). You’ll want to specify the ratio of 0s to 1s in the pos_weight argument. For example, if the entire training set has 90000 0s and 10000 1s, you’d set it equal to 90000/10000 = 9.

  2. Anytime you’re choosing appropriate transforms, you will want to first familiarize yourself with the dataset, to see what it should look like. Then use matplotlib to plot the image after the transform is applied, to see what each transform is doing. Is the transform bounded to fall within the normal variation of the dataset for that particular attribute?
from matplotlib import pyplot as plt

# single image, no batch dim
plt.imshow(x.permute(1, 2, 0))
plt.show()

# or, if it's a batch of images
for i in range(x.size(0)):
    plt.imshow(x[i].permute(1, 2, 0))
    plt.show()

i’m sorry, i’ve just started using pytorch several weeks ago, so i’m a little bit confused by your statement. would you explain more about what the ratio and pos_weight mean?

It’s explained here: multilabel classification - How to calculate unbalanced weights for BCEWithLogitsLoss in pytorch - Stack Overflow

Suppose all of your ground truths for training were in 1 tensor y with dimensions [train_size, channels, height, width]. Then you’d just run:

# dummy stand-in for a stacked ground-truth tensor of that shape
y = torch.randint(0, 2, (59, 1, 256, 256))

zeros = (y == 0).sum()
ones = (y == 1).sum()

pos_weight = zeros/ones

print(pos_weight)

This is necessary because otherwise your model could just guess all zeros and be mostly correct pixel-wise (since most of the ground truth is zero), yet totally useless as a segmentation. Setting pos_weight helps the loss function weigh the ones and zeros evenly.
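
Once you have the ratio, it just gets passed to the loss via the pos_weight argument; a minimal sketch (zeros/ones above already yields a tensor, which is what the argument expects):

# weigh positive pixels pos_weight times more heavily in the loss
criterion = nn.BCEWithLogitsLoss(pos_weight=pos_weight.to(device))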

assuming my ground-truth images are stored in the train_y variable, the code would be like this?

zeros = (train_y == 0).sum()
ones = (train_y == 1).sum()

pos_weight = zeros/ones

print(pos_weight)

update: my pos_weight result

Pos_weight:  tensor(1818.2500)

That seems very high. Is that only one example (i.e., an outlier), or did you place all of the ground truths for training in train_y?

sorry, i think i made a mistake by not looping through all the ground truths. here is my new pos_weight value

Pos_weight:  tensor(548.0924)

and here’s my code to calculate that

class ProcessDataset(Dataset):
    def __getitem__(self, idx):
        # doing transforms
        ...............................
        return img_x, img_y

dataset = ProcessDataset(train_x, train_y)

# accumulate the pixel counts over every ground-truth mask,
# then take the ratio over the whole training set
negatives = 0
positives = 0
for i in range(len(dataset)):
    image, mask = dataset[i]
    negatives += torch.sum(mask == 0)
    positives += torch.sum(mask == 1)

pos_weight = negatives / positives

print("Pos_weight: ", pos_weight)