Help with image semantic segmentation using U-Net

I am still new to PyTorch and would be really grateful for some help with this task, which I’m doing for learning purposes. I have previously implemented image classification with help from some tutorials, but I am getting stuck on how to transition to semantic segmentation, where each pixel gets a class label. My goal here is to use a U-Net for the semantic segmentation task.

So far I have code to load the Carvana dataset (https://www.kaggle.com/c/carvana-image-masking-challenge) and train a U-Net (PyTorch model class code from https://github.com/shreyaspadhy/UNet-Zoo, original paper: https://arxiv.org/pdf/1505.04597.pdf), but I am running into an error from the loss function that I don’t know how to debug (error message and code included below). Some notes about the data:

  • Each image and mask is size 128 x 128
  • Train and Test datasets are set up as PyTorch Dataset objects, where each element is a tuple of tensors (image and mask)

I understand that in semantic segmentation I should compute a pixel-wise loss rather than a loss for the whole image. Some questions I have:

  • Can I use a normal CrossEntropyLoss() for semantic segmentation, or do I need to use something else?
  • How can I get pixel-wise predictions from the model and use those to compute the accuracy? (The code in the ‘STEP 7: TRAIN THE MODEL’ below is from a classification task, and I’m not sure how to adapt it to segmentation; see my untested sketch after this list.)
  • Can I follow the same general steps in the training loop as for classification, or are there things I need to change?
  • How can I actually perform forward passes after the model is trained in order to collect masks as the output? (This would be a ‘STEP 8: USE THE MODEL’, I suppose, but I’m not sure how to approach it.)
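
To make the second question concrete, here is my untested guess at adapting the classification accuracy code to count pixels instead of images (the shapes in the comments are just what I expect the model to produce, and may well be wrong):

correct = 0
total = 0
for images, labels in test_loader:
    images = images.unsqueeze(1).float()           # add channel dim, as in the training loop
    outputs = model(images)                        # [batch, num_classes, 128, 128]?
    _, predicted = torch.max(outputs, 1)           # class index per pixel: [batch, 128, 128]?
    correct += (predicted == labels).sum().item()  # count matching pixels
    total += labels.numel()                        # count all pixels, not just images
accuracy = 100 * correct / total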

I’ve included my code and error message below. The code is organized into 7 steps (1 = load dataset, 2 = make dataset iterable, 3 = create model class, 4 = instantiate model class, 5 = instantiate loss class, 6 = instantiate optimizer class, 7 = train model):

import torch
import torch.nn as nn
import torchvision.transforms as transforms
import torchvision.datasets as dsets
import os
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from skimage import io, transform, color
from torch.utils.data import Dataset, DataLoader
from pathlib import Path
import json

'''
########################################## STEP 1: LOADING DATASET
'''

class CarvanaDataset(Dataset):
    
    def __init__(self, img_dir, mask_dir, train):
        """
        Args:
        root_dir (string): Directory with all the images
        train (boolean): Whether the dataset is training data (train = True, test = False)
        """
        self.img_dir = img_dir
        self.mask_dir = mask_dir
        self.train = train
        self.image_list = []
        self.counter = 0 #counter used to break out of the loading loop early, for faster testing
        
        # Now iterate through all images in the directory to pull them into python

        for filename in os.listdir(img_dir):
            img = io.imread(os.path.join(img_dir, filename)) #this joins the paths of my root directory with each filename
            maskname = filename[:-4] + "_mask.png" #cuts off .jpg and adds '_mask.png'
            mask = io.imread(os.path.join(mask_dir, maskname))
            self.image_list.append((torch.tensor(color.rgb2gray(img)), torch.tensor(mask))) #converts car image to grayscale
            self.counter = self.counter + 1
            
            if (self.counter > 100):
                break
        
    def __len__(self):
        return len(self.image_list)
        
    def __getitem__(self, idx):
        
        return self.image_list[idx]

data_set = CarvanaDataset(img_dir = 'train-128/', mask_dir = 'train_masks-128/', train = True)

# Split the dataset into train and test halves (slicing works because __getitem__ passes the slice through to the underlying list)
train_dataset = data_set[:int(len(data_set)/2)]
test_dataset = data_set[int(len(data_set)/2):]
plt.imshow(test_dataset[50][0]) #sanity check: display one grayscale image

'''
########################################## STEP 2: MAKING DATASET ITERABLE
'''

batch_size = 5 #we will feed the model 5 images at a time
n_iters = 300
num_epochs = n_iters / (len(train_dataset) / batch_size) #iterations per epoch = dataset size / batch size, so this gives roughly n_iters iterations in total
num_epochs = int(num_epochs)

train_loader = torch.utils.data.DataLoader(dataset = train_dataset,
                                           batch_size = batch_size,
                                           shuffle = True) #shuffle ensures we traverse images in different order across epochs

test_loader = torch.utils.data.DataLoader(dataset = test_dataset,
                                          batch_size = batch_size,
                                          shuffle = False) #no need to shuffle for evaluation

'''
########################################## STEP 3: CREATE MODEL CLASS
'''
class UNet(nn.Module):
    def __init__(self, num_channels=1, num_classes=2):
        super(UNet, self).__init__()
        num_feat = [64, 128, 256, 512, 1024]

        self.down1 = nn.Sequential(Conv3x3(num_channels, num_feat[0]))

        self.down2 = nn.Sequential(nn.MaxPool2d(kernel_size=2),
                                   Conv3x3(num_feat[0], num_feat[1]))

        self.down3 = nn.Sequential(nn.MaxPool2d(kernel_size=2),
                                   Conv3x3(num_feat[1], num_feat[2]))

        self.down4 = nn.Sequential(nn.MaxPool2d(kernel_size=2),
                                   Conv3x3(num_feat[2], num_feat[3]))

        self.bottom = nn.Sequential(nn.MaxPool2d(kernel_size=2),
                                    Conv3x3(num_feat[3], num_feat[4]))

        self.up1 = UpConcat(num_feat[4], num_feat[3])
        self.upconv1 = Conv3x3(num_feat[4], num_feat[3])

        self.up2 = UpConcat(num_feat[3], num_feat[2])
        self.upconv2 = Conv3x3(num_feat[3], num_feat[2])

        self.up3 = UpConcat(num_feat[2], num_feat[1])
        self.upconv3 = Conv3x3(num_feat[2], num_feat[1])

        self.up4 = UpConcat(num_feat[1], num_feat[0])
        self.upconv4 = Conv3x3(num_feat[1], num_feat[0])

        self.final = nn.Sequential(nn.Conv2d(num_feat[0],
                                             num_classes,
                                             kernel_size=1),
                                   nn.Softmax2d())
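        # Note: nn.CrossEntropyLoss applies log-softmax internally, so with the
        # Softmax2d above, the loss effectively sees softmax applied twice.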

    def forward(self, inputs, return_features=False):
        # print(inputs.data.size())
        down1_feat = self.down1(inputs)
        # print(down1_feat.size())
        down2_feat = self.down2(down1_feat)
        # print(down2_feat.size())
        down3_feat = self.down3(down2_feat)
        # print(down3_feat.size())
        down4_feat = self.down4(down3_feat)
        # print(down4_feat.size())
        bottom_feat = self.bottom(down4_feat)

        # print(bottom_feat.size())
        up1_feat = self.up1(bottom_feat, down4_feat)
        # print(up1_feat.size())
        up1_feat = self.upconv1(up1_feat)
        # print(up1_feat.size())
        up2_feat = self.up2(up1_feat, down3_feat)
        # print(up2_feat.size())
        up2_feat = self.upconv2(up2_feat)
        # print(up2_feat.size())
        up3_feat = self.up3(up2_feat, down2_feat)
        # print(up3_feat.size())
        up3_feat = self.upconv3(up3_feat)
        # print(up3_feat.size())
        up4_feat = self.up4(up3_feat, down1_feat)
        # print(up4_feat.size())
        up4_feat = self.upconv4(up4_feat)
        # print(up4_feat.size())

        if return_features:
            outputs = up4_feat
        else:
            outputs = self.final(up4_feat)

        return outputs

class Conv3x3(nn.Module):
    def __init__(self, in_feat, out_feat):
        super(Conv3x3, self).__init__()

        self.conv1 = nn.Sequential(nn.Conv2d(in_feat, out_feat,
                                             kernel_size=3,
                                             stride=1,
                                             padding=1),
                                   nn.BatchNorm2d(out_feat),
                                   nn.ReLU())

        self.conv2 = nn.Sequential(nn.Conv2d(out_feat, out_feat,
                                             kernel_size=3,
                                             stride=1,
                                             padding=1),
                                   nn.BatchNorm2d(out_feat),
                                   nn.ReLU())

    def forward(self, inputs):
        outputs = self.conv1(inputs)
        outputs = self.conv2(outputs)
        return outputs


class Conv3x3Drop(nn.Module):
    def __init__(self, in_feat, out_feat):
        super(Conv3x3Drop, self).__init__()

        self.conv1 = nn.Sequential(nn.Conv2d(in_feat, out_feat,
                                             kernel_size=3,
                                             stride=1,
                                             padding=1),
                                   nn.Dropout(p=0.2),
                                   nn.ReLU())

        self.conv2 = nn.Sequential(nn.Conv2d(out_feat, out_feat,
                                             kernel_size=3,
                                             stride=1,
                                             padding=1),
                                   nn.BatchNorm2d(out_feat),
                                   nn.ReLU())

    def forward(self, inputs):
        outputs = self.conv1(inputs)
        outputs = self.conv2(outputs)
        return outputs


class Conv3x3Small(nn.Module):
    def __init__(self, in_feat, out_feat):
        super(Conv3x3Small, self).__init__()

        self.conv1 = nn.Sequential(nn.Conv2d(in_feat, out_feat,
                                             kernel_size=3,
                                             stride=1,
                                             padding=1),
                                   nn.ELU(),
                                   nn.Dropout(p=0.2))

        self.conv2 = nn.Sequential(nn.Conv2d(out_feat, out_feat,
                                             kernel_size=3,
                                             stride=1,
                                             padding=1),
                                   nn.ELU())

    def forward(self, inputs):
        outputs = self.conv1(inputs)
        outputs = self.conv2(outputs)
        return outputs


class UpConcat(nn.Module):
    def __init__(self, in_feat, out_feat):
        super(UpConcat, self).__init__()

        self.up = nn.UpsamplingBilinear2d(scale_factor=2)

        # self.deconv = nn.ConvTranspose2d(in_feat, out_feat,
        #                                  kernel_size=3,
        #                                  stride=1,
        #                                  dilation=1)

        self.deconv = nn.ConvTranspose2d(in_feat,
                                         out_feat,
                                         kernel_size=2,
                                         stride=2)

    def forward(self, inputs, down_outputs):
        # TODO: Upsampling required after deconv?
        # outputs = self.up(inputs)
        outputs = self.deconv(inputs)
        out = torch.cat([down_outputs, outputs], 1)
        return out


class UpSample(nn.Module):
    def __init__(self, in_feat, out_feat):
        super(UpSample, self).__init__()

        self.up = nn.Upsample(scale_factor=2, mode='nearest')

        self.deconv = nn.ConvTranspose2d(in_feat,
                                         out_feat,
                                         kernel_size=2,
                                         stride=2)

    def forward(self, inputs, down_outputs):
        # TODO: Upsampling required after deconv?
        outputs = self.up(inputs)
        # outputs = self.deconv(inputs)
        out = torch.cat([outputs, down_outputs], 1)
        return out

'''
########################################## STEP 4: INSTANTIATE MODEL CLASS
'''
model = UNet()

#####################
# USE GPU FOR MODEL #
#####################

if torch.cuda.is_available():
    model.cuda()

'''
########################################## STEP 5: INSTANTIATE LOSS CLASS
'''
criterion = nn.CrossEntropyLoss()

'''
########################################## STEP 6: INSTANTIATE OPTIMIZER CLASS
'''
learning_rate = 0.01 #note: 1 iteration is one batch of 5 images; we update parameters once per batch

optimizer = torch.optim.SGD(model.parameters(), lr = learning_rate)

'''
########################################## STEP 7: TRAIN THE MODEL
'''
iter = 0
for epoch in range(num_epochs):
    for i, (images, labels) in enumerate(train_loader):
        #####################
        # USE GPU FOR MODEL #
        #####################
        if torch.cuda.is_available():
            images = images.unsqueeze(1).type(torch.FloatTensor).cuda()
            labels = labels.cuda()
        else:
            images = images.unsqueeze(1).type(torch.FloatTensor)
       
        # Clear gradients w.r.t. parameters
        optimizer.zero_grad() #PyTorch accumulates gradients by default, so we clear them each iteration to avoid mixing in gradients from previous batches
        
        # Forward pass to get output/logits
        # For segmentation, outputs has size [batch_size, num_classes, 128, 128]: one score per class per pixel. Higher value = more likely.
        outputs = model(images)
        
        # Calculate Loss: softmax --> cross entropy loss
        loss = criterion(outputs, labels)
        
        # Getting gradients w.r.t. parameters
        loss.backward()
        
        # Updating parameters
        optimizer.step()
        
        iter += 1
        
        if (iter % 50 == 0):
            # Calculate Accuracy, for every 50 iterations
            correct = 0
            total = 0
            # Iterate through the test dataset
            for images, labels in test_loader:
                #####################
                # USE GPU FOR MODEL #
                #####################
                if torch.cuda.is_available():
                    images = images.unsqueeze(1).type(torch.FloatTensor).cuda() #add channel dim: [batch, 128, 128] -> [batch, 1, 128, 128]
                else:
                    images = images.unsqueeze(1).type(torch.FloatTensor)

                # Forward pass only to get logits/output
                outputs = model(images)
                
                # Get predictions from the maximum value
                _, predicted = torch.max(outputs.data, 1) #torch.max returns (values, indices) along dim 1; the indices are the predicted class for each pixel
                
                # Total number of labels
                total += labels.size(0)
                
                #####################
                # USE GPU FOR MODEL #
                #####################
                # Total correct predictions: predicted lives on the GPU while labels stayed on the CPU, so move predicted back before comparing
                if torch.cuda.is_available():
                    correct += (predicted.cpu() == labels.cpu()).sum().item()
                else:
                    correct += (predicted == labels).sum().item()
                    
            accuracy = 100 * (correct / total)
            
            # Print Loss
            print('Iteration: {}. Loss: {}. Accuracy: {}'.format(iter, loss.item(), accuracy))

''' ERRORS
---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
<ipython-input-38-2baa99b3f165> in <module>
     22 
     23         # Calculate Loss: softmax --> cross entropy loss
---> 24         loss = criterion(outputs, labels)
     25 
     26         # Getting gradients w.r.t. parameters

~\Anaconda3\lib\site-packages\torch\nn\modules\module.py in __call__(self, *input, **kwargs)
    487             result = self._slow_forward(*input, **kwargs)
    488         else:
--> 489             result = self.forward(*input, **kwargs)
    490         for hook in self._forward_hooks.values():
    491             hook_result = hook(self, input, result)

~\Anaconda3\lib\site-packages\torch\nn\modules\loss.py in forward(self, input, target)
    902     def forward(self, input, target):
    903         return F.cross_entropy(input, target, weight=self.weight,
--> 904                                ignore_index=self.ignore_index, reduction=self.reduction)
    905 
    906 

~\Anaconda3\lib\site-packages\torch\nn\functional.py in cross_entropy(input, target, weight, size_average, ignore_index, reduce, reduction)
   1968     if size_average is not None or reduce is not None:
   1969         reduction = _Reduction.legacy_get_string(size_average, reduce)
-> 1970     return nll_loss(log_softmax(input, 1), target, weight, None, ignore_index, None, reduction)
   1971 
   1972 

~\Anaconda3\lib\site-packages\torch\nn\functional.py in nll_loss(input, target, weight, size_average, ignore_index, reduce, reduction)
   1790         ret = torch._C._nn.nll_loss(input, target, weight, _Reduction.get_enum(reduction), ignore_index)
   1791     elif dim == 4:
-> 1792         ret = torch._C._nn.nll_loss2d(input, target, weight, _Reduction.get_enum(reduction), ignore_index)
   1793     else:
   1794         # dim == 3 or dim > 4

RuntimeError: Expected object of scalar type Long but got scalar type Byte for argument #2 'target'
'''

Thank you for your help!

Hi Joseph,

As the error message says, the target tensor is a ByteTensor while a LongTensor is expected, so you need to make the target Long. There are a few ways to do this; for example, in the dataset:

self.image_list.append((torch.tensor(color.rgb2gray(img)), torch.tensor(mask, dtype=torch.long))) #the mask is now created as a LongTensor via dtype=torch.long
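
If it helps, a quick sanity check (this uses the data_set object from your code; exact dtypes depend on your image files, but PNG masks typically load as uint8):

img, mask = data_set[0]
print(img.dtype, mask.dtype)  # before the fix: torch.float64 torch.uint8 (ByteTensor)
                              # after the fix:  torch.float64 torch.int64 (LongTensor)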

Hi, you should change your target to long. Just change the loss computation to the code below:

loss = criterion(outputs, labels.long())
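
For reference, here is a minimal self-contained sketch of the shapes nn.CrossEntropyLoss expects in the 2D (segmentation) case; the tensors are made-up stand-ins for your model outputs and masks:

import torch
import torch.nn as nn

criterion = nn.CrossEntropyLoss()

# 2D case: input is [batch, num_classes, H, W] raw scores,
# target is [batch, H, W] holding long class indices in [0, num_classes - 1]
fake_outputs = torch.randn(5, 2, 128, 128)               # stand-in for model(images)
fake_labels = torch.randint(0, 2, (5, 128, 128)).long()  # stand-in for your masks
loss = criterion(fake_outputs, fake_labels)
print(loss.item())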

Thanks, that solved this error. Just so I can debug more easily in the future: how did you know that labels was the target whose datatype I needed to change? I wasn’t sure how to interpret which tensor was meant by “RuntimeError: Expected object of scalar type Long but got scalar type Byte for argument #2 ‘target’”. I could see it was the target in “ret = torch._C._nn.nll_loss2d(input, target, weight, _Reduction.get_enum(reduction), ignore_index)”, but I still didn’t know which of my tensors that referred to.

Also, what determines the initial datatype of the tensors? I didn’t specify Byte anywhere in my code when initializing them.

Thank you!

Thanks for your solution, that makes sense. Is ByteTensor the default when converting to grayscale, or was this because of my original image format? Also, does ‘target’ always refer to the labels in the error message, or can it refer to something else? I was initially confused about what was meant by target and which tensor’s type to change.

Thank you!