Segmentation of images with multichannel masks

Hi, I am working on segmentation of the CamVid dataset. The masks have shape [h, w, 3]: instead of a single class label, each pixel holds the RGB color code of its class. I found segmentation code for this dataset that, in order to use CrossEntropyLoss, converts each mask from shape [h, w, 3] to a one-hot mask of shape [32, h, w], where 32 is the number of classes. Another snippet converts a mask of shape [32, h, w] back to a 3-channel mask of shape [h, w, 3]. After training, when I visualize the predicted masks, the results are very poor, so I am trying to find where I made a mistake. First, I checked the two snippets that convert between an RGB mask and a one-hot (binary) mask. The code is below.
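For reference, nn.CrossEntropyLoss does not strictly need the one-hot [32, h, w] form: its classic target is a tensor of class indices per pixel (dtype long), and recent PyTorch versions also accept class probabilities with the same shape as the model output. A minimal numpy sketch of the index form, assuming id2code is the dictionary returned by Color_map below; rgb_to_index is a hypothetical helper, and marking unmatched pixels with -100 (the default ignore_index of nn.CrossEntropyLoss) is a design choice, not part of the original code.

```python
import numpy as np

def rgb_to_index(img, id2code, unmatched=-100):
    """Convert an RGB mask [h, w, 3] to a class-index mask [h, w] (int64).

    Pixels whose color is not in id2code get `unmatched`; -100 is the
    default ignore_index of torch.nn.CrossEntropyLoss, so such pixels
    would be skipped by the loss.
    """
    img = np.asarray(img)
    index = np.full(img.shape[:2], unmatched, dtype=np.int64)
    for class_id, color in id2code.items():
        # boolean [h, w] map of the pixels that have exactly this color
        index[np.all(img == np.array(color), axis=-1)] = class_id
    return index

id2code = {0: (64, 128, 64), 1: (192, 0, 128)}
img = np.array([[[64, 128, 64], [192, 0, 128]],
                [[0, 0, 0], [64, 128, 64]]], dtype=np.float32)
print(rgb_to_index(img, id2code).tolist())   # [[0, 1], [-100, 0]]
```

The result can be turned into a target tensor with torch.from_numpy; a batch of such [h, w] masks stacks to [batch, h, w].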

import numpy as np
import pandas as pd

def Color_map(df):

    '''
    Builds lookup tables between class colors, class ids and class names.
    Parameters:
        df: Path to a CSV file with one row per class: its name and its RGB values.
    Returns:
        code2id: A dictionary with color as keys and class id as values.
        id2code: A dictionary with class id as keys and color as values.
        name2id: A dictionary with class name as keys and class id as values.
        id2name: A dictionary with class id as keys and class name as values.
    '''
    cls = pd.read_csv(df)
    
    # build a tuple with the RGB color code of each class
    # len(cls.name) is the number of classes that we have: 32
    # e.g. [(64, 128, 64), (192, 0, 128), ...]
    color_code = [tuple(cls.drop("name",axis=1).loc[idx]) for idx in range(len(cls.name))]
    
    # map each color code to its class id
    code2id = {v: k for k, v in enumerate(list(color_code))}
    
    # map each class id to its color code
    id2code = {k: v for k, v in enumerate(list(color_code))}
    
    # collect the name of each class
    color_name = [cls['name'][idx] for idx in range(len(cls.name))]
    
    # map each class name to its class id
    name2id = {v: k for k, v in enumerate(list(color_name))}
    
    # map each class id to its class name
    id2name = {k: v for k, v in enumerate(list(color_name))}
    
    return(code2id, id2code, name2id, id2name)
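To see what the four dictionaries look like, here is a sketch that builds the same lookup tables with Python's csv module instead of pandas (a swapped-in technique, used only so the example has no extra dependency). It assumes the CSV header is name,r,g,b; the two rows are illustrative, with color codes taken from the comment above and class names that are assumptions.

```python
import csv
import io

def color_map_from_csv(f):
    """Build code2id, id2code, name2id, id2name from a class_dict-style CSV.

    Assumes a header row 'name,r,g,b' (hypothetical column names).
    """
    rows = list(csv.DictReader(f))
    id2code = {i: (int(r["r"]), int(r["g"]), int(r["b"])) for i, r in enumerate(rows)}
    id2name = {i: r["name"] for i, r in enumerate(rows)}
    code2id = {code: i for i, code in id2code.items()}
    name2id = {name: i for i, name in id2name.items()}
    return code2id, id2code, name2id, id2name

sample = io.StringIO("name,r,g,b\nAnimal,64,128,64\nArchway,192,0,128\n")
code2id, id2code, name2id, id2name = color_map_from_csv(sample)
print(id2code)   # {0: (64, 128, 64), 1: (192, 0, 128)}
print(code2id)   # {(64, 128, 64): 0, (192, 0, 128): 1}
```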

def mask_to_rgb(mask, id2code):
    ''' 
        Converts a one-hot mask of shape [batch_size, num_classes, h, w]
        to an RGB mask of shape [batch_size, h, w, 3].
        
        Parameters:
            mask: A one-hot (binary) mask
            id2code: Dictionary mapping each class id to its color code
        Returns:
            output: An RGB mask of shape [batch_size, h, w, 3]
    '''
    ## Since the mask is one-hot encoded, argmax over the class axis
    ## (axis 1; axis 0 is the batch) returns the label of each pixel,
    ## a number in the range 0-31.
    ## Note: a pixel whose vector is all zeros (its color matched no class)
    ## also gets label 0, i.e. the color of class 0.
    
    single_layer = np.argmax(mask, axis=1)
    
    ## output RGB mask of shape [batch_size, h, w, 3]
    output = np.zeros((mask.shape[0],mask.shape[2],mask.shape[3],3))
    
    for k in id2code.keys():
        
        output[single_layer==k] = id2code[k]
        
    return(output.astype(np.float32))
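np.argmax returns the index of the first maximum, so for a pixel whose one-hot vector is all zeros (a color that rgb_to_mask did not recognise) it returns 0, and mask_to_rgb silently paints that pixel with the color of class 0. The sketch below shows this on a tiny array, together with a hypothetical variant, mask_to_rgb_strict, that marks such pixels with a separate fill color (white here, an arbitrary assumption) so they stand out when visualized.

```python
import numpy as np

def mask_to_rgb_strict(mask, id2code, fill=(255, 255, 255)):
    """Variant of mask_to_rgb: [b, C, h, w] one-hot -> [b, h, w, 3] RGB.

    Pixels whose one-hot vector is all zeros get `fill` (a hypothetical
    marker color) instead of silently receiving the color of class 0.
    """
    labels = np.argmax(mask, axis=1)     # [b, h, w]; all-zero pixels -> 0
    matched = mask.max(axis=1) > 0       # [b, h, w]; False where no class is set
    out = np.zeros(labels.shape + (3,), dtype=np.float32)
    for k, color in id2code.items():
        out[(labels == k) & matched] = color
    out[~matched] = fill
    return out

id2code = {0: (64, 128, 64), 1: (192, 0, 128)}
one_hot = np.zeros((1, 2, 1, 2))         # batch 1, 2 classes, 1x2 pixels
one_hot[0, 1, 0, 0] = 1.0                # pixel (0, 0) is class 1
                                         # pixel (0, 1) has no class at all
print(np.argmax(one_hot, axis=1))        # [[[1 0]]]: plain argmax says class 0
print(mask_to_rgb_strict(one_hot, id2code)[0, 0].tolist())
# [[192.0, 0.0, 128.0], [255.0, 255.0, 255.0]]
```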

def rgb_to_mask(img, id2code):
    ''' 
        Converts a single RGB mask of shape [h, w, 3] to a one-hot mask of shape
        [num_classes, h, w].
        
        Parameters:
            img: An RGB mask
            id2code: Dictionary mapping each class id to its unique color code
        Returns:
            out: A one-hot (binary) mask of shape [num_classes, h, w]
    '''
    
    # num_classes is equal to len(id2code)
    num_classes = len(id2code)
    
    # output shape (h, w, num_classes), e.g. (720, 960, 32)
    shape = img.shape[:2]+(num_classes,)
    
    # array of zeros with this shape and dtype float64
    out = np.zeros(shape, dtype=np.float64)
    
    # fill one channel per class
    for i, cls in enumerate(id2code):
        
        #print(f'i: {i}, cls: {cls}')
        
        # img.reshape((-1,3)) flattens the mask to one row per pixel, keeping the 3 color channels.
        # Each pixel is compared with the color code of class i: pixels that match exactly
        # get 1 in channel i, all others get 0, and the result is reshaped back to (h, w).
        # A pixel that matches no class color stays 0 in every channel.
        out[:,:,i] = np.all(np.array(img).reshape((-1,3)) == id2code[i], axis=1).reshape(shape[:2])
        
    # out has shape (h, w, classes); return it as (classes, h, w)
    return(out.transpose(2,0,1))
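rgb_to_mask can also be written with broadcasting instead of a Python loop. The sketch below (rgb_to_mask_vectorized is a hypothetical name) compares every pixel with every class color at once, assuming the keys of id2code are 0 to num_classes - 1, as Color_map builds them. Summing the result over the class axis gives 0 exactly for the pixels that matched no class, which is a quick check to run on a mask before training.

```python
import numpy as np

def rgb_to_mask_vectorized(img, id2code):
    """Broadcasting variant of rgb_to_mask: [h, w, 3] -> [num_classes, h, w].

    Assumes id2code keys are 0..num_classes-1, as built by Color_map.
    """
    img = np.asarray(img)
    colors = np.array([id2code[k] for k in range(len(id2code))])   # [C, 3]
    # compare every pixel with every class color: [C, h, w, 3] -> [C, h, w]
    hits = np.all(img[None, :, :, :] == colors[:, None, None, :], axis=-1)
    return hits.astype(np.float64)

id2code = {0: (64, 128, 64), 1: (192, 0, 128)}
img = np.array([[[64, 128, 64], [192, 0, 128], [10, 20, 30]]], dtype=np.float32)  # h=1, w=3
one_hot = rgb_to_mask_vectorized(img, id2code)
print(one_hot.shape)                 # (2, 1, 3)
unmatched = one_hot.sum(axis=0) == 0
print(unmatched)                     # [[False False  True]]
```

The trade-off is memory: for a 720x960 mask and 32 classes, the intermediate boolean array holds 32 * 720 * 960 * 3, about 66 million elements (roughly 66 MB), whereas the loop version only ever holds one class at a time.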

I expect that if I convert a mask from the training set to a one-hot mask with rgb_to_mask and then back to an RGB mask with mask_to_rgb, the result will be identical to the original mask. However, as the following code shows, they are not the same, and I do not know where the problem is.
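The round trip can only reproduce the original mask if every pixel color appears in id2code. A self-contained sketch with two class colors: the mask round-trips, but the same mask with its channels in BGR order (what cv2.imread returns, assuming the masks were read with OpenCV) does not, because the unrecognised pixel falls back to class 0. to_one_hot and to_rgb are simplified, hypothetical stand-ins for rgb_to_mask and mask_to_rgb.

```python
import numpy as np

id2code = {0: (64, 128, 64), 1: (192, 0, 128)}
colors = np.array([id2code[k] for k in range(len(id2code))])   # [C, 3]

def to_one_hot(img):
    # [h, w, 3] -> [C, h, w], same idea as rgb_to_mask
    return np.all(img[None] == colors[:, None, None, :], axis=-1).astype(np.float32)

def to_rgb(one_hot):
    # [C, h, w] -> [h, w, 3], same idea as mask_to_rgb (no batch axis)
    return colors[np.argmax(one_hot, axis=0)].astype(np.float32)

mask = np.array([[[64, 128, 64], [192, 0, 128]]], dtype=np.float32)   # h=1, w=2
print(np.array_equal(to_rgb(to_one_hot(mask)), mask))       # True

# the same mask with channels in BGR order (e.g. as returned by cv2.imread)
bgr = mask[..., ::-1]
print(np.array_equal(to_rgb(to_one_hot(bgr)), bgr))         # False
```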

print(f'mask_sample_shape: {mask_sample.shape} mask_sample: {mask_sample.dtype}, mask_type: {type(mask_sample)}')
_, id2code,_,_ = Color_map(os.path.join(path,'class_dict.csv'))
mask_cls = rgb_to_mask(mask_sample, id2code)
print(f'mask_cls: {mask_cls.shape} mask_cls:{mask_cls.dtype} mask_cls_type: {type(mask_cls)}')
## Now converting mask_cls to mask_rgb
mask_rgb = mask_to_rgb(mask_cls[np.newaxis,...], id2code)
mask_rgb = mask_rgb.squeeze(0)
print(f'mask_rgb_shape: {mask_rgb.shape} mask_rgb: {mask_rgb.dtype}, mask_rgb_type: {type(mask_rgb)}')
comparison = mask_rgb==mask_sample
print(comparison.all())

out:
mask_sample_shape: (720, 960, 3) mask_sample: float32, mask_type: <class 'numpy.ndarray'>
mask_cls: (32, 720, 960) mask_cls:float32 mask_cls_type: <class 'numpy.ndarray'>
mask_rgb_shape: (720, 960, 3) mask_rgb: float32, mask_rgb_type: <class 'numpy.ndarray'>
False
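To find out which pixels break the comparison, one option is to list the distinct colors in the mask that are not class colors, with their pixel counts. In the sketch below, unknown_colors is a hypothetical helper, and the third pixel color (128, 64, 96) is made up: it is the midpoint of the two class colors, the kind of value that bilinear resizing of a mask can produce.

```python
import numpy as np

def unknown_colors(mask, id2code):
    """List (color, pixel_count) for every distinct color in an [h, w, 3] mask
    that is not one of the class colors in id2code."""
    pixels = np.asarray(mask, dtype=np.float64).reshape(-1, 3)
    colors, counts = np.unique(pixels, axis=0, return_counts=True)
    known = {tuple(float(v) for v in code) for code in id2code.values()}
    return [(tuple(c.tolist()), int(n)) for c, n in zip(colors, counts)
            if tuple(c.tolist()) not in known]

id2code = {0: (64, 128, 64), 1: (192, 0, 128)}
mask = np.array([[[64, 128, 64], [192, 0, 128], [128, 64, 96]]], dtype=np.float32)
print(unknown_colors(mask, id2code))   # [((128.0, 64.0, 96.0), 1)]

# With the variables from the question (not run here):
# print(unknown_colors(mask_sample, id2code)[:10])
# print(np.any(mask_rgb != mask_sample, axis=-1).sum(), "mismatched pixels")
```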

Does anyone have any idea where the problem is? I also visualized both masks, and they do not look the same either.

Thanks