Loss function in image segmentation (nn.BCEWithLogitsLoss)

Hi, I’ve been struggling with image segmentation for a long time.
Specifically, I want to predict a brain-cancer mask from brain MRIs and the corresponding brain-cancer mask image files.
Each brain MRI has shape 1x3x256x256 (RGB) and each mask has shape 1x1x256x256 (black and white).

At first I tried nn.CrossEntropyLoss(), but it failed, probably because of my poor understanding of the expected dimensions.

I looked for solutions here, but couldn’t really understand what it means for the target to have size [batch_size, H, W] and contain the class indices. Does this mean the target shape isn’t actually [batch_size, H, W], but rather something like [batch_size, indices, H, W]?
That’s what I gathered from @ptrblck’s answer.

Anyway, I then found another loss function, nn.BCEWithLogitsLoss().
It looked easier to use than nn.CrossEntropyLoss because it just needs the inputs and targets to have the same shape, but using it created another problem.
Here’s my full code:

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torch.utils.data import Dataset
import numpy as np
import os, sys, random, time
import matplotlib.pyplot as plt
from PIL import Image
import re
device = 'cuda' if torch.cuda.is_available() else 'cpu'

torch.manual_seed(777)
if device == 'cuda':
    torch.cuda.manual_seed_all(777)
class BasicDataset(Dataset):
    def __init__(self, imgs_dir, masks_dir):
        self.imgs_dir = imgs_dir
        self.masks_dir = masks_dir
        self.mriids = next(os.walk(self.imgs_dir))[2]
        self.maskids = next(os.walk(self.masks_dir))[2]
        
        def atoi(text):
            return int(text) if text.isdigit() else text

        def natural_keys(text):
            return [atoi(c) for c in re.split(r'(\d+)', text) ] 
        
        self.mriids = sorted(self.mriids, key = natural_keys)
        self.maskids = sorted(self.maskids, key = natural_keys)

    def __len__(self):
        return len(self.mriids)

    def __getitem__(self, idx):

        mriidx = self.mriids[idx] #img file name
        maskidx = self.maskids[idx] #mask file name
        
        mask_file = os.path.join(self.masks_dir, maskidx)
        img_file = os.path.join(self.imgs_dir, mriidx)
        
        img = Image.open(img_file).convert("RGB")
        mask = Image.open(mask_file).convert("L")
        
        mask = np.array(mask)
        img = np.array(img)

        mask = np.expand_dims(mask, axis=2)  # H x W -> H x W x 1

        # HWC -> CHW, the layout PyTorch conv layers expect
        img = np.transpose(img, (2, 0, 1))
        mask = np.transpose(mask, (2, 0, 1))

        # collect the non-background values present in the mask and build one
        # label per object; only "masks" in the target dict is used later on
        obj_ids = np.unique(mask)
        obj_ids = obj_ids[1:]  # drop the background value 0

        num_objs = len(obj_ids)
        labels = torch.ones((num_objs,), dtype=torch.int64)
        mask = torch.as_tensor(mask, dtype=torch.uint8)

        image_id = torch.tensor([idx])

        target = {}
        target["labels"] = labels
        target["masks"] = mask
        target["image_id"] = image_id
        
        return img, target
# model: a U-Net-style encoder/decoder (note: no skip connections between
# encoder and decoder stages, unlike the original U-Net)
class unet(nn.Module):
    def __init__(self):
        super(unet, self).__init__()
        self.encoder1 = nn.Sequential(nn.Conv2d(3, 64, kernel_size = 3, stride = 1, padding = 1, bias=False),
                                     nn.BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True),
                                     nn.ReLU(inplace=True),
                                     nn.Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=1, bias=False),
                                     nn.BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True),
                                     nn.ReLU(inplace=True))
        
        self.encoder2 = nn.Sequential(nn.Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=1, bias=False),
                                     nn.BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True),
                                     nn.ReLU(inplace=True),
                                     nn.Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=1, bias=False),
                                     nn.BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True),
                                     nn.ReLU(inplace=True))
        
        self.encoder3 = nn.Sequential(nn.Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=1, bias=False),
                                     nn.BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True),
                                     nn.ReLU(inplace=True),
                                     nn.Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=1, bias=False),
                                     nn.BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True),
                                     nn.ReLU(inplace=True))
        
        self.encoder4 = nn.Sequential(nn.Conv2d(256, 512, kernel_size=(3, 3), stride=(1, 1), padding=1, bias=False),
                                     nn.BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True),
                                     nn.ReLU(inplace=True),
                                     nn.Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=1, bias=False),
                                     nn.BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True),
                                     nn.ReLU(inplace=True))
        
        self.pool1 = nn.MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
        
        self.pool2 = nn.MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)

        self.pool3 = nn.MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)

        self.pool4 = nn.MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
        
        self.conv_mid = nn.Sequential(nn.Conv2d(512, 1024, kernel_size=(3, 3), stride=(1, 1), padding=1, bias=False),
                                     nn.BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True),
                                     nn.ReLU(inplace=True),   
                                     nn.Conv2d(1024, 1024, kernel_size=(3, 3), stride=(1, 1), padding=1, bias=False),
                                     nn.BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True),
                                     nn.ReLU(inplace=True))
        
        self.decoder4 = nn.Sequential(nn.Conv2d(1024, 512, kernel_size=(3, 3), stride=(1, 1), padding=1, bias=False),
                                     nn.BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True),
                                     nn.ReLU(inplace=True),
                                     nn.Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=1, bias=False),
                                     nn.BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True),
                                     nn.ReLU(inplace=True))
        
        self.decoder3 = nn.Sequential(nn.Conv2d(512, 256, kernel_size=(3, 3), stride=(1, 1), padding=1, bias=False),
                                     nn.BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True),
                                     nn.ReLU(inplace=True),
                                     nn.Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=1, bias=False),
                                     nn.BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True),
                                     nn.ReLU(inplace=True))
        
        self.decoder2 = nn.Sequential(nn.Conv2d(256, 128, kernel_size=(3, 3), stride=(1, 1), padding=1, bias=False),
                                     nn.BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True),
                                     nn.ReLU(inplace=True),
                                     nn.Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=1, bias=False),
                                     nn.BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True),
                                     nn.ReLU(inplace=True))
        
        self.decoder1 = nn.Sequential(nn.Conv2d(128, 64, kernel_size=(3, 3), stride=(1, 1), padding=1, bias=False),
                                     nn.BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True),
                                     nn.ReLU(inplace=True),
                                     nn.Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=1, bias=False),
                                     nn.BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True),
                                     nn.ReLU(inplace=True))
        
        self.upconv4 = nn.ConvTranspose2d(1024, 1024, kernel_size=(2, 2), stride=2)
        
        self.upconv3 = nn.ConvTranspose2d(512, 512, kernel_size=(2, 2), stride=(2, 2))
        
        self.upconv2 = nn.ConvTranspose2d(256, 256, kernel_size=(2, 2), stride=(2, 2))
        
        self.upconv1 = nn.ConvTranspose2d(128, 128, kernel_size=(2, 2), stride=(2, 2))
        
        self.conv1x1_out = nn.Conv2d(64, 1, kernel_size=(1, 1), stride=(1, 1))

    def forward(self, x):
        out = self.encoder1(x)
        out = self.pool1(out)
        out = self.encoder2(out)
        out = self.pool2(out)
        out = self.encoder3(out)
        out = self.pool3(out)
        out = self.encoder4(out)
        out = self.pool4(out)
        out = self.conv_mid(out)
        out = self.upconv4(out)
        out = self.decoder4(out)
        out = self.upconv3(out)
        out = self.decoder3(out)
        out = self.upconv2(out)
        out = self.decoder2(out)
        out = self.upconv1(out)
        out = self.decoder1(out)
        out = self.conv1x1_out(out)

        return out
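
# optional sanity check (a sketch): one dummy forward pass should produce a
# mask-shaped output before any training is attempted
# print(unet()(torch.randn(1, 3, 256, 256)).shape)  # expect torch.Size([1, 1, 256, 256])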
    
model = unet()#.to(device)
epochs = 1000
batch_size = 1
lr = 0.00001
momentum = 0.99

optimizer = optim.SGD(model.parameters(), lr=lr, momentum=momentum)
loss_func = nn.BCEWithLogitsLoss()#.to(device)

gen = BasicDataset('/home/intern/Desktop/YH/Brain_MRI/BrainMRI_train/MRI/MRI/', '/home/intern/Desktop/YH/Brain_MRI/BrainMRI_train/mask/mask/')
train_loader = DataLoader(gen, batch_size=batch_size)

total_batch = len(gen)

model.train()
print("start training")
for epoch in range(epochs):
    avg_cost = 0.0        # running average of the loss over this epoch
    t0 = time.time()

    for mri, true_mask in train_loader:

        mri = mri.type(torch.FloatTensor) / 255.0            # scale pixels to [0, 1]
        true_mask = true_mask["masks"]
        true_mask = (true_mask > 0).type(torch.FloatTensor)  # binarize: BCE targets must be 0 or 1
        
        mri = mri#.to(device)
        true_mask = true_mask#.to(device)
        
        pred_mask = model(mri)
        loss = loss_func(pred_mask, true_mask)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        
        avg_cost += loss.item() / total_batch  # .item() keeps the sum off the autograd graph
    t1 = time.time()
    print('[Epoch:{}], cost = {}, time = {}'.format(epoch + 1, avg_cost, t1 - t0))
print('training Finished!')

Output:

start training

When I run this code, I don’t see any progress past ‘start training’, and even interrupting the kernel doesn’t work.
I’m not using the GPU right now; when I move the model, inputs, and targets to the GPU, some kind of CUDA error occurs.

Hello Yoonho -

Let me try to answer part of your question.

(I haven’t looked at your dataloader or model code, so I won’t comment
on it.)

Let me try to give you some general context here.

The shape of your input doesn’t really matter to your loss function – it
just has to match what your model is expecting for its input.

But I interpret this as meaning you have an input image of 256x256
pixels with three (RGB color) channels. I assume that the “1” means
that you have a batch size of nBatch = 1. (There is nothing necessarily
wrong with using a batch size of 1, but you might get better throughput
and / or training with a larger batch size.)

So a batch of input samples will be a tensor of shape
(nBatch, 3, 256, 256).

Next, I assume that you are performing binary image segmentation.
That is, you want to label each pixel in your image with a “0” (this
pixel is not part of a cancerous tumor) or “1” (this pixel is cancerous).

Similarly I assume that your “black-and-white” image masks – the
“targets” that you use for your training – are also per-pixel "0"s and
"1"s. So a batch of targets will have shape (nBatch, 256, 256)
(and each value of such a tensor will be 0 or 1).

Lastly, for a single input sample, the output of your network should
be a 256x256-pixel “image,” so that when applied to a batch of
input samples, your network will produce output tensors of shape
(nBatch, 256, 256). Each “pixel” of your output tensor (your
model’s prediction) will be a “raw-score” logit that runs from -inf
to +inf. As you can see your model’s predictions that get fed into
your BCEWithLogits and your targets that also get fed into your
loss function are indeed of the same shape.
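
To make the shapes concrete, here is a minimal sketch (with made-up
random tensors, not your actual model) of a BCEWithLogitsLoss call:

import torch
import torch.nn as nn

nBatch = 1
# predictions: one raw-score logit per pixel, any real value
pred = torch.randn(nBatch, 256, 256)
# targets: per-pixel 0.0 or 1.0, exactly the same shape as pred
target = (torch.rand(nBatch, 256, 256) > 0.5).float()

loss = nn.BCEWithLogitsLoss()(pred, target)
print(loss)  # a single scalar

(A singleton channel dimension, i.e. shape (nBatch, 1, 256, 256) for
both tensors, also works, as long as predictions and targets match.)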

Based on my understanding of your use case (black-and-white image
masks), BCEWithLogitsLoss (rather than CrossEntropyLoss) will be
the appropriate loss function.

You asked whether the target is really something like
[batch_size, indices, H, W]. No. Neither for BCEWithLogitsLoss nor CrossEntropyLoss will
your target have shape (nBatch, nClass, 256, 256). It means
that each pixel (in your batch of shape (nBatch, 256, 256)) will
have a value equal to the “class index.” For BCEWithLogitsLoss
this value will be 0 or 1 (and in [0, nClass - 1], inclusive, for
CrossEntropyLoss). (If you were to use CrossEntropyLoss,
however, your prediction would have a class dimension, that is,
would have shape (nBatch, nClass, 256, 256).)
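
As a similar sketch (again with made-up tensors), the CrossEntropyLoss
version for nClass classes would look like this:

import torch
import torch.nn as nn

nBatch, nClass = 1, 2
# predictions carry a class dimension: one logit per class per pixel
pred = torch.randn(nBatch, nClass, 256, 256)
# targets carry class indices in [0, nClass - 1] and have no class dimension
target = torch.randint(0, nClass, (nBatch, 256, 256))  # dtype int64

loss = nn.CrossEntropyLoss()(pred, target)
print(loss)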

You also said BCEWithLogitsLoss looked easier. Not only might it be easier, it is the right loss function for your use case.

Yes, as discussed above, you will want both your predictions (inputs to
the loss function) and your targets to have shape (nBatch, 256, 256).

I can’t tell you whether your code is producing the right shapes and
values (or if it might have other issues), but check, at least, that your
predictions and targets are consistent with what I outlined above.

Best.

K. Frank


Hi Frank,
Thanks for your thorough answer!
One thing I don’t understand is this:

You say that:

Each “pixel” of your output tensor (your
model’s prediction) will be a “raw-score” logit that runs from -inf
to +inf. As you can see your model’s predictions that get fed into
your BCEWithLogits and your targets that also get fed into your
loss function are indeed of the same shape.

So, I suppose that after the loss has been calculated, we do backpropagation and an optimizer step based on this loss, and this improves our model. But if our model outputs these “logits”, how can we end up with something that looks like our label image/mask (0s and 1s), so that we can actually plot the output “image” and see for ourselves how close it is to the labels?

Hope I’m making sense.

Sincerely, Hubert.

Well, for that you convert the logits into hard 0/1 predictions. With a single-channel output like yours, apply a sigmoid and threshold at 0.5 (equivalently, threshold the raw logits at 0); with a multi-channel CrossEntropyLoss-style output you would instead take the argmax over the class dimension. Either way you get an output in the same format as your labels, and you can use OpenCV or PIL to show that image after you have converted the CUDA tensor to a NumPy array on the CPU.
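
As a minimal sketch (assuming pred_mask has shape (1, 1, 256, 256), with
model and mri as in the code above):

import torch
import matplotlib.pyplot as plt

with torch.no_grad():
    pred_mask = model(mri)            # logits, shape (1, 1, 256, 256)
    probs = torch.sigmoid(pred_mask)  # map logits into [0, 1]
    binary = (probs > 0.5).float()    # hard 0/1 mask, same format as the labels

# move to CPU, drop the batch/channel dims, convert to numpy, and plot
plt.imshow(binary.cpu().squeeze().numpy(), cmap='gray')
plt.show()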