One class merged with the ignore class while training semantic segmentation

While training a semantic segmentation model, I had 23 classes including the ignore label, but my predictions contain only 22 classes. I used cross-entropy loss and set ignore_index to 255.

From the predicted image, it looks like one of the classes (class_name: vegetation, id: 21) has merged with the ignore index, as seen in the image below.

The image below shows the unique values for the ground truth and the prediction, along with the original, ground-truth, and predicted images.

Please help me understand what may have caused this!

What is the shape of your output activation?
In a segmentation use case, your output activation will most likely have the shape [batch_size, nb_classes, height, width].
If nb_classes was defined as 22, you will only get these 22 labels (indices 0 to 21).
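To illustrate the point about the output shape, here is a minimal sketch with a random tensor standing in for the model output. The sizes are made up for illustration and are not taken from the poster's model; the only claim is that taking the argmax over the class dimension can never produce a label outside [0, nb_classes - 1].

```python
import torch

# Hypothetical sizes, chosen only for illustration.
batch_size, nb_classes, height, width = 2, 22, 8, 8

# Stand-in for the model output (logits), shape [N, C, H, W].
scores = torch.randn(batch_size, nb_classes, height, width)

# The predicted label of each pixel is the index of the largest logit along dim=1.
pred = scores.argmax(dim=1)  # shape [N, H, W]

print(pred.shape)
# The extremes always stay within [0, nb_classes - 1].
print(int(pred.min()), int(pred.max()))
```

In other words, a value such as 255 can never show up in the prediction unless the model has at least 256 output channels.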

Thanks for replying, @ptrblck. I am using the DeepLab v3+ architecture, and my segmentation output has the shape [batch_size, nb_classes, height, width]. I just printed the unique values to check whether the ignore index is there or not.

During training, the output tensor has the shape [batch, 23, 512, 512]. After that, I compute the labels using _, pred = torch.max(scores, dim=1). The list of unique values of pred contains 22 (which is the ignore index).

But at inference time, when I use the trained weights, the output has the shape [batch, 23, 512, 512] and the unique values only go up to a maximum label of 21.
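A check like the one described here can be written with torch.unique. The sketch below uses tiny hand-made tensors rather than real data, and label_histogram is a hypothetical helper name; it only shows one way to compare which labels occur in the ground truth and in the prediction.

```python
import torch

def label_histogram(labels):
    """Map each label value present in `labels` to its pixel count."""
    values, counts = torch.unique(labels, return_counts=True)
    return dict(zip(values.tolist(), counts.tolist()))

# Toy ground truth with ignore pixels (255) and a toy prediction without them.
gt = torch.tensor([[0, 0, 255], [21, 255, 255]])
pred = torch.tensor([[0, 0, 21], [21, 21, 21]])

gt_hist = label_histogram(gt)
pred_hist = label_histogram(pred)

# Labels that occur in the ground truth but are never predicted.
missing = set(gt_hist) - set(pred_hist)
print(gt_hist, pred_hist, missing)
```

Looking at the counts rather than only the set of unique values also makes it easier to see which predicted class absorbs the pixels of a missing label.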

It seems that the model output contains the background class + all other 22 classes.
If I understand your workflow correctly, you have passed the background class index as ignore_index, so that your model never learned how to predict the background?

If that’s the case, I would expect the model never to output the background class during inference either.
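The reasoning above can be checked directly: pixels whose target equals ignore_index contribute nothing to the loss, so no gradient flows back to the logits at those locations. The following is a small sketch with random logits and a made-up 4x4 target, not the poster's training code.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
num_classes, ignore_index = 23, 255

# Random logits standing in for the model output, shape [N, C, H, W].
logits = torch.randn(1, num_classes, 4, 4, requires_grad=True)

# Made-up target: the top two rows are "unlabeled", the rest have real classes.
target = torch.randint(0, 22, (1, 4, 4))
target[0, :2, :] = ignore_index

criterion = nn.CrossEntropyLoss(ignore_index=ignore_index)
loss = criterion(logits, target)
loss.backward()

# Total gradient magnitude on ignored vs. labeled pixels.
grad_ignored = logits.grad[0, :, :2, :].abs().sum().item()
grad_labeled = logits.grad[0, :, 2:, :].abs().sum().item()
print(grad_ignored, grad_labeled)
```

The gradient on the ignored rows is zero, so the model receives no training signal there and is free to output any class for those pixels.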

FYI, I am using the ApolloScape dataset. The available labels are listed below; the trainId column has no class for background, so I am using the ignore index + 22 other classes.

For the architecture and the loss, I am passing the arguments below:

model = DeepLab(num_classes=23, backbone=resnet101, output_stride=16)
criterion = nn.CrossEntropyLoss(ignore_index=255)

labels = [
    #     name                    clsId    id   trainId   category  catId  hasInstances  ignoreInEval   color
    Label('others'              ,    0 ,    0,   255   , '其他'    ,   0  ,False , True  , 0x000000 ),
    Label('rover'               , 0x01 ,    1,   255   , '其他'    ,   0  ,False , True  , 0X000000 ),
    Label('sky'                 , 0x11 ,   17,    0    , '天空'    ,   1  ,False , False , 0x4682B4 ),
    Label('car'                 , 0x21 ,   33,    1    , '移动物体',   2  ,True  , False , 0x00008E ),
    Label('car_groups'          , 0xA1 ,  161,    1    , '移动物体',   2  ,True  , False , 0x00008E ),  
    Label('motorbicycle'        , 0x22 ,   34,    2    , '移动物体',   2  ,True  , False , 0x0000E6 ),
    Label('motorbicycle_group'  , 0xA2 ,  162,    2    , '移动物体',   2  ,True  , False , 0x0000E6 ),
    Label('bicycle'             , 0x23 ,   35,    3    , '移动物体',   2  ,True  , False , 0x770B20 ),
    Label('bicycle_group'       , 0xA3 ,  163,    3    , '移动物体',   2  ,True  , False , 0x770B20 ),
    Label('person'              , 0x24 ,   36,    4    , '移动物体',   2  ,True  , False , 0x0080c0 ),
    Label('person_group'        , 0xA4 ,  164,    4    , '移动物体',   2  ,True  , False , 0x0080c0 ),
    Label('rider'               , 0x25 ,   37,    5    , '移动物体',   2  ,True  , False , 0x804080 ),
    Label('rider_group'         , 0xA5 ,  165,    5    , '移动物体',   2  ,True  , False , 0x804080 ),
    Label('truck'               , 0x26 ,   38,    6    , '移动物体',   2  ,True  , False , 0x8000c0 ),
    Label('truck_group'         , 0xA6 ,  166,    6    , '移动物体',   2  ,True  , False , 0x8000c0 ), 
    Label('bus'                 , 0x27 ,   39,    7    , '移动物体',   2  ,True  , False , 0xc00040 ),
    Label('bus_group'           , 0xA7 ,  167,    7    , '移动物体',   2  ,True  , False , 0xc00040 ),
    Label('tricycle'            , 0x28 ,   40,    8    , '移动物体',   2  ,True  , False , 0x8080c0 ),
    Label('tricycle_group'      , 0xA8 ,  168,    8    , '移动物体',   2  ,True  , False , 0x8080c0 ),
    Label('road'                , 0x31 ,   49,    9    , '平面'    ,   3  ,False , False , 0xc080c0 ),
    Label('siderwalk'           , 0x32 ,   50,    10   , '平面'    ,   3  ,False , False , 0xc08040 ),
    Label('traffic_cone'        , 0x41 ,   65,    11   , '路间障碍',   4  ,False , False , 0x000040 ),
    Label('road_pile'           , 0x42 ,   66,    12   , '路间障碍',   4  ,False , False , 0x0000c0 ),
    Label('fence'               , 0x43 ,   67,    13   , '路间障碍',   4  ,False , False , 0x404080 ),
    Label('traffic_light'       , 0x51 ,   81,    14   , '路边物体',   5  ,False , False , 0xc04080 ),
    Label('pole'                , 0x52 ,   82,    15   , '路边物体',   5  ,False , False , 0xc08080 ),
    Label('traffic_sign'        , 0x53 ,   83,    16   , '路边物体',   5  ,False , False , 0x004040 ),
    Label('wall'                , 0x54 ,   84,    17   , '路边物体',   5  ,False , False , 0xc0c080 ),
    Label('dustbin'             , 0x55 ,   85,    18   , '路边物体',   5  ,False , False , 0x4000c0 ),
    Label('billboard'           , 0x56 ,   86,    19   , '路边物体',   5  ,False , False , 0xc000c0 ),
    Label('building'            , 0x61 ,   97,    20   , '建筑'    ,   6  ,False , False , 0xc00080 ),
    Label('bridge'              , 0x62 ,   98,    255  , '建筑'    ,   6  ,False , True  , 0x808000 ),
    Label('tunnel'              , 0x63 ,   99,    255  , '建筑'    ,   6  ,False , True  , 0x800000 ),
    Label('overpass'            , 0x64 ,  100,    255  , '建筑'    ,   6  ,False , True  , 0x408040 ),
    Label('vegatation'          , 0x71 ,  113,    21   , '自然'    ,   7  ,False , False , 0x808040 ),
    Label('unlabeled'           , 0xFF ,  255,    255  , '未标注'  ,   8  ,False , True  , 0xFFFFFF ),
]
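For readers unfamiliar with this kind of table, one common way to apply it is a lookup table that converts the id stored in each mask pixel to its trainId. The sketch below copies only a handful of (id, trainId) pairs from the table above, and it assumes that the raw masks store the id column per pixel; that assumption is not confirmed in this thread.

```python
import torch

# A few (id, trainId) pairs copied from the label table; not the full list.
ID_TO_TRAIN_ID = [
    (0, 255),    # others
    (17, 0),     # sky
    (33, 1),     # car
    (161, 1),    # car_groups
    (49, 9),     # road
    (97, 20),    # building
    (98, 255),   # bridge
    (113, 21),   # vegatation (spelling as in the table)
    (255, 255),  # unlabeled
]

# 256-entry lookup table; ids not listed here default to the ignore value 255.
lut = torch.full((256,), 255, dtype=torch.long)
for raw_id, train_id in ID_TO_TRAIN_ID:
    lut[raw_id] = train_id

# Toy mask of raw ids, shape [H, W].
raw_mask = torch.tensor([[17, 33, 161], [113, 98, 255]])

# Indexing the table with the mask converts every pixel in one step.
train_mask = lut[raw_mask]
print(train_mask)
```

Every class marked with trainId 255 (others, rover, bridge, tunnel, overpass, unlabeled) ends up with the same value, which is the one passed as ignore_index to the loss.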

Thanks for the information.

Could you check, what the white color represents in your ground truth image?
I would assume it’s class 0 (sky), but that wouldn’t make sense, as parts of the buildings and the car are also marked with this color.

What you said is correct; all the white areas in the image are actually the ignore index. The sky, buildings, and host vehicle should all have corresponding classes, but the dataset labels some parts of them as 255.
Is there any workaround?

This would explain why your model predicts random classes for these pixel locations.
Basically, your model will not get any training signal for these locations, as they are ignored, so it can output any arbitrary class there. In your case it seems the vegetation class was picked.

You could transform all 255 class values to an “unknown” or background class and let the model learn this additional class instead.
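A minimal sketch of this suggestion follows. The constant names and the helper remap_ignore_to_unknown are hypothetical, and index 22 for the extra class is an assumption based on the trainIds 0 to 21 being taken; the loss is created without ignore_index so that every pixel is supervised.

```python
import torch
import torch.nn as nn

IGNORE_VALUE = 255   # value marking unlabeled pixels in the masks
UNKNOWN_CLASS = 22   # assumed index of the extra class (trainIds 0..21 are taken)
NUM_CLASSES = 23     # 22 real classes + 1 "unknown" class

def remap_ignore_to_unknown(target):
    """Return a copy of `target` with every IGNORE_VALUE pixel set to UNKNOWN_CLASS."""
    remapped = target.clone()
    remapped[remapped == IGNORE_VALUE] = UNKNOWN_CLASS
    return remapped

# Toy target, shape [N, H, W].
target = torch.tensor([[[0, 21, 255], [255, 9, 20]]])
target = remap_ignore_to_unknown(target)

# Random logits standing in for the model output, shape [N, C, H, W].
logits = torch.randn(1, NUM_CLASSES, 2, 3)

criterion = nn.CrossEntropyLoss()  # no ignore_index: all pixels count
loss = criterion(logits, target)
print(target, loss.item())
```

The model then needs NUM_CLASSES output channels, and the extra class can be masked out or colored separately when visualizing or evaluating the predictions.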

It worked. I trained my model for 10 epochs after assigning all the ignore labels to an extra class (class id: 22). I have attached the results I got after 10 epochs.

Thanks for your support, @ptrblck. I have closed the issue with your last answer.