TuSimple ENet: from binary lane segmentation to multi-channel output segmentation

Hello everyone,

I am new to computer vision and I am trying to implement ENet semantic segmentation for lane detection.

I currently have a working version of binary segmentation, achieving an mIoU of ~77.

However, the provided TuSimple benchmark evaluation requires separate lanes to be identified.

Below is my dataloader:

# imports needed by the dataset (root, train and test are dataset paths defined elsewhere)
import os
import glob
import json

import cv2
import numpy as np
import matplotlib.pyplot as plt
import torch
from torch.utils import data

class tuSimpleDataset(data.Dataset):
    def __init__(self, file_path=root, size=[640,368], split='train', gray=True, intensity=0):
        self.width = size[0]
        self.height = size[1]
        self.n_seg = 5 # total number of possible lanes (lanes <= 5)
        self.size = size
        
        if split == 'train':
            self.file_path = train
        elif split == 'test':
            self.file_path = test
        
     
        self.flags = {'size': size, 'gray': gray, 'split': split, 'intensity': intensity}
       
        self.json_lists = glob.glob(os.path.join(self.file_path, '*.json'))
        self.labels = []
        for json_list in self.json_lists:
            self.labels += [json.loads(line) for line in open(json_list)]
        self.lanes = [label['lanes'] for label in self.labels]
        self.y_samples = [label['h_samples'] for label in self.labels]
        self.raw_files = [label['raw_file'] for label in self.labels]
        
        self.img = np.zeros(size, np.uint8)
        self.label_img = np.zeros(size,np.uint8)
        self.ins_img = np.zeros((0, size[0], size[1]), np.uint8)
        self.len = len(self.labels)
            
    def preprocess(self):
        # CLAHE normalization
        img = cv2.cvtColor(self.img, cv2.COLOR_RGB2LAB)
        img_plane = cv2.split(img)
        clahe = cv2.createCLAHE(clipLimit=2.0, tileGridSize=(8,8))
        img_plane[0] = clahe.apply(img_plane[0])
        img = cv2.merge(img_plane)
        self.img = cv2.cvtColor(img, cv2.COLOR_LAB2RGB)
    
    def image_resize(self):
        ins = []
        self.img = cv2.resize(self.img, tuple(self.flags['size']), interpolation=cv2.INTER_CUBIC)
        # use nearest-neighbour interpolation for the label maps so resizing
        # does not create invalid class values
        self.label_img = cv2.resize(self.label_img, tuple(self.flags['size']), interpolation=cv2.INTER_NEAREST)
        for i in range(len(self.ins_img)):
            dst = cv2.resize(self.ins_img[i], tuple(self.flags['size']), interpolation=cv2.INTER_NEAREST)
            ins.append(dst)
        self.ins_img = np.array(ins, dtype=np.uint8)
    
    def get_lane_image(self, idx):
        lane_pts = [[(x,y) for (x,y) in zip(lane, self.y_samples[idx]) if x >= 0] for lane in self.lanes[idx]]
        while len(lane_pts) < self.n_seg:
            lane_pts.append(list())
        self.img = plt.imread(os.path.join(self.file_path, self.raw_files[idx]))
        self.height, self.width, _ = self.img.shape
        self.label_img = np.zeros((self.height, self.width), dtype=np.uint8)
        self.ins_img = np.zeros((0, self.height, self.width), dtype=np.uint8)
        
        for i, lane_pt in enumerate(lane_pts):
            cv2.polylines(self.label_img, np.int32([lane_pt]), isClosed=False, color=(1), thickness=15)
            gt = np.zeros((self.height, self.width), dtype=np.uint8)
            gt = cv2.polylines(gt, np.int32([lane_pt]), isClosed=False, color=(1), thickness=7)
            self.ins_img = np.concatenate([self.ins_img, gt[np.newaxis]])

    
    
    def __getitem__(self, idx):
        self.get_lane_image(idx)
        self.image_resize()
        self.preprocess()

        if self.flags['split']=='train':
            #self.random_transform() # disable random_transform
            # Below changes shape from [368,640,3] to [3, 368, 640]
            # self.img = np.array(np.transpose(self.img, (2,0,1)), dtype=np.float32)
            self.label_img = np.array(self.label_img, dtype=np.float32)
            self.ins_img = np.array(self.ins_img, dtype=np.float32)
            return torch.Tensor(self.img), torch.LongTensor(self.label_img), torch.Tensor(self.ins_img)
        else:
            self.img = np.array(np.transpose(self.img, (2,0,1)), dtype=np.float32)
            return torch.Tensor(self.img)
    
    
    def __len__(self):
        return self.len

In order to get distinguishable lanes in the prediction output, am I right to say that I should increase the number of output channels of ENet from 2 to 6 (5 possible lanes + background)?

In get_lane_image(self, idx), the current code draws a polyline for each lane's points.

My question is: how can I build a label image whose pixel values equal their class values, and is this the method I should use to achieve my requirement?

i.e., for an example 7x7 label input:

[[0 5 0 0 0 0 0]
 [0 5 0 0 0 3 0]
 [0 0 0 0 0 3 0]
 [0 1 0 0 0 3 0]
 [0 1 0 2 0 0 4]
 [0 1 0 2 0 0 4]
 [1 0 0 2 0 0 4]]

the pixel values of my 368x640 label input would be the respective classes:

0 for background,
1 for lane 1,
2 for lane 2 and so on

The lanes do not have to be specific (i.e. class 1 does not have to be the leftmost lane; it could be any lane), and if only 3 lanes exist, classes 4 and 5 would simply not appear in the label. I am also wondering whether this would affect my training and accuracy compared to binary segmentation.
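A minimal sketch of what I have in mind is below (draw_multiclass_label is a hypothetical helper that mirrors the drawing loop in get_lane_image, with the class id simply being the lane index + 1):

import cv2
import numpy as np

def draw_multiclass_label(lane_pts, height, width):
    # draw each lane into a single-channel label with its own class value;
    # 0 stays background, the lanes get ids 1..len(lane_pts)
    label_img = np.zeros((height, width), dtype=np.uint8)
    for i, lane_pt in enumerate(lane_pts):
        if len(lane_pt) == 0:
            continue  # missing lanes simply leave their class id unused
        cv2.polylines(label_img, np.int32([lane_pt]), isClosed=False,
                      color=(i + 1), thickness=15)
    return label_img

One thing I noticed is that such a class-valued label should be resized with nearest-neighbour interpolation, otherwise the resize blends neighbouring class ids into invalid values.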

In the picture, the first two columns are my initial input (input image and binary label image). I hope to be able to identify different lanes in my output. The last column is empty since there are only 4 lanes in this input example.

The ENet model that I am currently using is below. I believe I would have to change more than just the output channels in order to get distinct lanes in the output, like in the image below?

[image: input image, binary label image, and per-lane output columns]

"""Efficient Neural Network"""
import torch
import torch.nn as nn
import torch.nn.functional as F

from .segbase import SegBaseModel
from .model_zoo import MODEL_REGISTRY
from ..modules import  _FCNHead
from ..config import cfg

__all__ = ['ENet']


@MODEL_REGISTRY.register()
class ENet(SegBaseModel):
    """Efficient Neural Network"""

    def __init__(self, **kwargs):
        super(ENet, self).__init__(need_backbone=False)
        self.initial = InitialBlock(13, **kwargs)

        self.bottleneck1_0 = Bottleneck(16, 16, 64, downsampling=True, **kwargs)
        self.bottleneck1_1 = Bottleneck(64, 16, 64, **kwargs)
        self.bottleneck1_2 = Bottleneck(64, 16, 64, **kwargs)
        self.bottleneck1_3 = Bottleneck(64, 16, 64, **kwargs)
        self.bottleneck1_4 = Bottleneck(64, 16, 64, **kwargs)

        self.bottleneck2_0 = Bottleneck(64, 32, 128, downsampling=True, **kwargs)
        self.bottleneck2_1 = Bottleneck(128, 32, 128, **kwargs)
        self.bottleneck2_2 = Bottleneck(128, 32, 128, dilation=2, **kwargs)
        self.bottleneck2_3 = Bottleneck(128, 32, 128, asymmetric=True, **kwargs)
        self.bottleneck2_4 = Bottleneck(128, 32, 128, dilation=4, **kwargs)
        self.bottleneck2_5 = Bottleneck(128, 32, 128, **kwargs)
        self.bottleneck2_6 = Bottleneck(128, 32, 128, dilation=8, **kwargs)
        self.bottleneck2_7 = Bottleneck(128, 32, 128, asymmetric=True, **kwargs)
        self.bottleneck2_8 = Bottleneck(128, 32, 128, dilation=16, **kwargs)

        self.bottleneck3_1 = Bottleneck(128, 32, 128, **kwargs)
        self.bottleneck3_2 = Bottleneck(128, 32, 128, dilation=2, **kwargs)
        self.bottleneck3_3 = Bottleneck(128, 32, 128, asymmetric=True, **kwargs)
        self.bottleneck3_4 = Bottleneck(128, 32, 128, dilation=4, **kwargs)
        self.bottleneck3_5 = Bottleneck(128, 32, 128, **kwargs)
        self.bottleneck3_6 = Bottleneck(128, 32, 128, dilation=8, **kwargs)
        self.bottleneck3_7 = Bottleneck(128, 32, 128, asymmetric=True, **kwargs)
        self.bottleneck3_8 = Bottleneck(128, 32, 128, dilation=16, **kwargs)

        self.bottleneck4_0 = UpsamplingBottleneck(128, 16, 64, **kwargs)
        self.bottleneck4_1 = Bottleneck(64, 16, 64, **kwargs)
        self.bottleneck4_2 = Bottleneck(64, 16, 64, **kwargs)

        self.bottleneck5_0 = UpsamplingBottleneck(64, 4, 16, **kwargs)
        self.bottleneck5_1 = Bottleneck(16, 4, 16, **kwargs)

        self.fullconv = nn.ConvTranspose2d(16, self.nclass, 2, 2, bias=False)

        self.__setattr__('decoder', ['bottleneck1_0', 'bottleneck1_1', 'bottleneck1_2', 'bottleneck1_3',
                                       'bottleneck1_4', 'bottleneck2_0', 'bottleneck2_1', 'bottleneck2_2',
                                       'bottleneck2_3', 'bottleneck2_4', 'bottleneck2_5', 'bottleneck2_6',
                                       'bottleneck2_7', 'bottleneck2_8', 'bottleneck3_1', 'bottleneck3_2',
                                       'bottleneck3_3', 'bottleneck3_4', 'bottleneck3_5', 'bottleneck3_6',
                                       'bottleneck3_7', 'bottleneck3_8', 'bottleneck4_0', 'bottleneck4_1',
                                       'bottleneck4_2', 'bottleneck5_0', 'bottleneck5_1', 'fullconv'])

    def forward(self, x):
        # init
        x = self.initial(x)

        # stage 1
        x, max_indices1 = self.bottleneck1_0(x)
        x = self.bottleneck1_1(x)
        x = self.bottleneck1_2(x)
        x = self.bottleneck1_3(x)
        x = self.bottleneck1_4(x)

        # stage 2
        x, max_indices2 = self.bottleneck2_0(x)
        x = self.bottleneck2_1(x)
        x = self.bottleneck2_2(x)
        x = self.bottleneck2_3(x)
        x = self.bottleneck2_4(x)
        x = self.bottleneck2_5(x)
        x = self.bottleneck2_6(x)
        x = self.bottleneck2_7(x)
        x = self.bottleneck2_8(x)

        # stage 3
        x = self.bottleneck3_1(x)
        x = self.bottleneck3_2(x)
        x = self.bottleneck3_3(x)
        x = self.bottleneck3_4(x)
        x = self.bottleneck3_6(x)
        x = self.bottleneck3_7(x)
        x = self.bottleneck3_8(x)

        # stage 4
        x = self.bottleneck4_0(x, max_indices2)
        x = self.bottleneck4_1(x)
        x = self.bottleneck4_2(x)

        # stage 5
        x = self.bottleneck5_0(x, max_indices1)
        x = self.bottleneck5_1(x)

        # out
        x = self.fullconv(x)
        return x


class InitialBlock(nn.Module):
    """ENet initial block"""

    def __init__(self, out_channels, norm_layer=nn.BatchNorm2d, **kwargs):
        super(InitialBlock, self).__init__()
        self.conv = nn.Conv2d(3, out_channels, 3, 2, 1, bias=False)
        self.maxpool = nn.MaxPool2d(2, 2)
        self.bn = norm_layer(out_channels + 3)
        self.act = nn.PReLU()

    def forward(self, x):
        x_conv = self.conv(x)
        x_pool = self.maxpool(x)
        x = torch.cat([x_conv, x_pool], dim=1)
        x = self.bn(x)
        x = self.act(x)
        return x


class Bottleneck(nn.Module):
    """Bottlenecks include regular, asymmetric, downsampling, dilated"""

    def __init__(self, in_channels, inter_channels, out_channels, dilation=1, asymmetric=False,
                 downsampling=False, norm_layer=nn.BatchNorm2d, **kwargs):
        super(Bottleneck, self).__init__()
        self.downsamping = downsampling
        if downsampling:
            self.maxpool = nn.MaxPool2d(2, 2, return_indices=True)
            self.conv_down = nn.Sequential(
                nn.Conv2d(in_channels, out_channels, 1, bias=False),
                norm_layer(out_channels)
            )

        self.conv1 = nn.Sequential(
            nn.Conv2d(in_channels, inter_channels, 1, bias=False),
            norm_layer(inter_channels),
            nn.PReLU()
        )

        if downsampling:
            self.conv2 = nn.Sequential(
                nn.Conv2d(inter_channels, inter_channels, 2, stride=2, bias=False),
                norm_layer(inter_channels),
                nn.PReLU()
            )
        else:
            if asymmetric:
                self.conv2 = nn.Sequential(
                    nn.Conv2d(inter_channels, inter_channels, (5, 1), padding=(2, 0), bias=False),
                    nn.Conv2d(inter_channels, inter_channels, (1, 5), padding=(0, 2), bias=False),
                    norm_layer(inter_channels),
                    nn.PReLU()
                )
            else:
                self.conv2 = nn.Sequential(
                    nn.Conv2d(inter_channels, inter_channels, 3, dilation=dilation, padding=dilation, bias=False),
                    norm_layer(inter_channels),
                    nn.PReLU()
                )
        self.conv3 = nn.Sequential(
            nn.Conv2d(inter_channels, out_channels, 1, bias=False),
            norm_layer(out_channels),
            nn.Dropout2d(0.1)
        )
        self.act = nn.PReLU()

    def forward(self, x):
        identity = x
        if self.downsamping:
            identity, max_indices = self.maxpool(identity)
            identity = self.conv_down(identity)

        out = self.conv1(x)
        out = self.conv2(out)
        out = self.conv3(out)
        out = self.act(out + identity)

        if self.downsamping:
            return out, max_indices
        else:
            return out


class UpsamplingBottleneck(nn.Module):
    """upsampling Block"""

    def __init__(self, in_channels, inter_channels, out_channels, norm_layer=nn.BatchNorm2d, **kwargs):
        super(UpsamplingBottleneck, self).__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, 1, bias=False),
            norm_layer(out_channels)
        )
        self.upsampling = nn.MaxUnpool2d(2)

        self.block = nn.Sequential(
            nn.Conv2d(in_channels, inter_channels, 1, bias=False),
            norm_layer(inter_channels),
            nn.PReLU(),
            nn.ConvTranspose2d(inter_channels, inter_channels, 2, 2, bias=False),
            norm_layer(inter_channels),
            nn.PReLU(),
            nn.Conv2d(inter_channels, out_channels, 1, bias=False),
            norm_layer(out_channels),
            nn.Dropout2d(0.1)
        )
        self.act = nn.PReLU()

    def forward(self, x, max_indices):
        out_up = self.conv(x)
        out_up = self.upsampling(out_up, max_indices)

        out_ext = self.block(x)
        out = self.act(out_up + out_ext)
        return out

I hope this is not too confusing, and I apologise for the multiple edits. Thank you all.

Your idea sounds generally right.
If I understand the use case correctly, your model is already able to predict the lanes, but cannot output each lane separately?
If that’s the case, would a post-processing step be OK for your use case or does it violate the benchmark evaluation?
You could e.g. use OpenCV to detect all lanes via a connected component analysis and assign a label to each lane.
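A rough sketch of that post-processing idea (binary_pred here is just a placeholder for your thresholded [H, W] binary lane mask):

import cv2
import numpy as np

# binary_pred: uint8 array of shape [H, W] with 1 where a lane was predicted,
# e.g. output.argmax(dim=1)[0].byte().cpu().numpy() for a 2-channel output
binary_pred = np.zeros((368, 640), dtype=np.uint8)  # placeholder
num_labels, lane_ids = cv2.connectedComponents(binary_pred)
# lane_ids is [H, W] with 0 for background and 1..num_labels-1 as a separate id per lane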

If that’s not possible, you could use a multi-label approach with 6 classes.
For this use case, your target should contain 6 channels, where each channel would be a “class” (in your case a lane), containing 1s for the current lane.
I’m not sure how well this multi-label classification approach would work for multiple instances of the same object type, but it might be easy to try out using your current model and training.
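A minimal sketch of such a target, built from the stacked per-lane masks your dataset already produces (ins_img is assumed to be the [5, H, W] array from get_lane_image):

import numpy as np

ins_img = np.zeros((5, 368, 640), dtype=np.uint8)  # placeholder for the stacked per-lane masks
background = (ins_img.sum(axis=0) == 0).astype(np.float32)
# target: [6, H, W]; channel 0 is background, channels 1..5 are the individual lanes
target = np.concatenate([background[np.newaxis], ins_img.astype(np.float32)])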


Thank you very much for your insight @ptrblck.

Yes, you are exactly right: I am able to predict the lanes, but not each lane separately.
A post-processing step should be OK for my benchmark evaluation. I did not think about that! I think I will do that first.

However, I would still like to try the multi-label approach as well to see if there are any differences in performance. I have actually tried passing a [16, 5, 368, 640] target tensor during training but received an error.

RuntimeError: 1only batches of spatial targets supported (3D tensors) but got targets of size: : [16, 5, 368, 640]

I believe I have to change my loss criterion?

import torch.nn as nn

class MixSoftmaxCrossEntropyLoss(nn.CrossEntropyLoss):
    def __init__(self, aux=True, aux_weight=0.2, ignore_index=-1, **kwargs):
        super(MixSoftmaxCrossEntropyLoss, self).__init__(ignore_index=ignore_index)
        self.aux = aux
        self.aux_weight = aux_weight

    def _aux_forward(self, *inputs, **kwargs):
        *preds, target = tuple(inputs)

        loss = super(MixSoftmaxCrossEntropyLoss, self).forward(preds[0], target)
        for i in range(1, len(preds)):
            aux_loss = super(MixSoftmaxCrossEntropyLoss, self).forward(preds[i], target)
            loss += self.aux_weight * aux_loss
        return loss

    def _multiple_forward(self, *inputs):
        *preds, target = tuple(inputs)
        loss = super(MixSoftmaxCrossEntropyLoss, self).forward(preds[0], target)
        for i in range(1, len(preds)):
            loss += super(MixSoftmaxCrossEntropyLoss, self).forward(preds[i], target)
        return loss

    def forward(self, *inputs, **kwargs):
        preds, target = tuple(inputs)
        inputs = tuple(list(preds) + [target])
        if self.aux:
            return dict(loss=self._aux_forward(*inputs))
        elif len(preds) > 1:
            return dict(loss=self._multiple_forward(*inputs))
        else:
            return dict(loss=super(MixSoftmaxCrossEntropyLoss, self).forward(*inputs))

The call that raised the error is loss = super(MixSoftmaxCrossEntropyLoss, self).forward(preds[0], target), as shown above.

I hope I have given enough information such that it is not confusing.
Thank you again for your help, I really appreciate it.

You could try to use the multi-label approach with nn.BCEWithLogitsLoss, where your model output and target would have the shape [batch_size, nb_classes, height, width]. The target should have values in the range [0, 1], where a 1 in the nb_classes dimension denotes that this class is active in the current pixel location. This would also allow you to predict overlapping lanes.
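A minimal shape sketch of that setup (random placeholder tensors, with the sizes from this thread):

import torch
import torch.nn as nn

batch_size, nb_classes, height, width = 16, 6, 368, 640
logits = torch.randn(batch_size, nb_classes, height, width)                    # model output
target = torch.randint(0, 2, (batch_size, nb_classes, height, width)).float()  # multi-label target
loss = nn.BCEWithLogitsLoss()(logits, target)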

Alternatively, a multi-class segmentation might also work, where you would use nn.CrossEntropyLoss with an output of [batch_size, nb_classes, height, width]. The target would have the shape [batch_size, height, width] and contain values in the range [0, nb_classes-1].
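And the corresponding multi-class sketch (again with placeholder tensors):

import torch
import torch.nn as nn

batch_size, nb_classes, height, width = 16, 6, 368, 640
logits = torch.randn(batch_size, nb_classes, height, width)           # [batch_size, nb_classes, height, width]
target = torch.randint(0, nb_classes, (batch_size, height, width))    # class indices in [0, nb_classes-1]
loss = nn.CrossEntropyLoss()(logits, target)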

I’m not sure which approach would work best.
I think the post-processing of your current approach might work quite well, as your use case is not really a multi-class or multi-label problem.


Thanks for your explanation.

But in a case where the stacked one-hot encoded targets have overlapping classes and the multi-class approach is needed, how would one handle the target? Taking torch.argmax of the stacked target only returns the index of the first 1 at each overlapping pixel. For example:

import torch
import torch.nn as nn

torch.manual_seed(19)
# random "logits" for a 3-class model
outputs = torch.randint(low=0, high=2, size=(1, 3, 5, 5), dtype=torch.float32)
# stacked one-hot encoded target, possibly with overlapping 1s across channels
labelA = torch.randint(low=0, high=2, size=(1, 3, 5, 5), dtype=torch.float32)

# inspect one row of pixels: several channels can be 1 at the same location
for i in range(labelA.shape[1]):
    print(labelA[:, i, 2, :])

# argmax collapses the stacked target to class indices, but at overlapping
# pixels it only returns the first channel that contains a 1
print(torch.argmax(labelA, 1))

loss = nn.CrossEntropyLoss()
loss_value = loss(outputs, torch.argmax(labelA, 1))
print(loss_value)

For a multi-label classification/segmentation you cannot use torch.argmax, but could use a threshold to get the predictions for each sample/pixel.
However, this also won’t work with nn.CrossEntropyLoss, which is used for a multi-class classification/segmentation, so use nn.BCEWithLogitsLoss instead.


When using BCEWithLogitsLoss I don’t need torch.argmax, since it directly takes the stacked one-hot encoded targets, with 1 representing the presence and 0 the absence of each class label in the stack.

I’m a bit confused about thresholding the logits (the network’s final-layer outputs, before the activation function) to get the predictions: aren’t the predictions (either torch.softmax(logits) >= 0.5 or logits >= 0.5) meant to be passed to the evaluation metric, e.g. dice_metric(predictions, stacked_one_hot_encoded_targets), while the logits are passed to the loss function, e.g. BCEWithLogitsLoss(logits, stacked_one_hot_encoded_targets)?

To get the predictions you could either directly threshold the logits or apply sigmoid on them and then use a threshold in [0, 1]. E.g. a threshold of 0.0 applied on the logits would correspond to 0.5 after applying the sigmoid.
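A small sketch of that equivalence (and of why the thresholded predictions should not go into the loss):

import torch

logits = torch.randn(2, 6, 4, 4, requires_grad=True)
preds_from_logits = (logits > 0.0).float()
preds_from_probs = (torch.sigmoid(logits) > 0.5).float()
print(torch.equal(preds_from_logits, preds_from_probs))  # True, barring values exactly at the threshold
print(preds_from_logits.requires_grad)                   # False: thresholding breaks the autograd graph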

Predictions are usually not wanted in loss functions, as the threshold op or argmax are not differentiable.


Thanks for the explanation.

I quite understand the reason for applying sigmoid to the logits and then a threshold of 0.5 to get the predictions for evaluation in the case of single-label multi-class segmentation; is this also applicable in the case of multi-label multi-class segmentation?

In addition, how do I get the metric per class for the stacked one-hot encoded targets?

Yes, it would also be applicable for multi-label classification tasks and would yield the predicted classes for each sample.

scikit-learn has some multi-label metrics, which could be useful.
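As a minimal sketch for the per-class question (per_class_iou is a hypothetical helper computing one IoU per channel from the thresholded predictions and the stacked one-hot targets):

import torch

def per_class_iou(preds, targets, eps=1e-6):
    # preds, targets: [batch_size, nb_classes, H, W] with values in {0, 1}
    dims = (0, 2, 3)                                # reduce over batch and spatial dims
    intersection = (preds * targets).sum(dim=dims)
    union = preds.sum(dim=dims) + targets.sum(dim=dims) - intersection
    return (intersection + eps) / (union + eps)     # one value per class/channel

scikit-learn’s jaccard_score with average=None could give a similar per-class breakdown if you flatten the predictions and targets to [num_pixels, nb_classes].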
