[SOLVED] Index to RGB: tensor casting from a given cmap

Hi,

I have a cmap for the labels of a dataset with 35 classes, given as a 35x3 matrix in which each row holds the r, g, b values used for the pixels of one class.

My final semantic segmentation output has size [batch, 35, height, width], so I get the index of the highest-probability class for each pixel like this:

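```python
import torch

num_classes = 35
output = output_from_segmentation_model()  # logits, shape [batch, 35, height, width]
_, pred = torch.max(output, dim=1)         # index of the highest-scoring class per pixel
```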

Hence pred is a tensor of class indices with pred.shape = [batch, height, width], pred.min() = 0 and pred.max() = 34.
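A quick sanity check of those shapes, using random logits as a hypothetical stand-in for output_from_segmentation_model() (batch 2 and a 4x5 image, chosen only for illustration). It also shows that torch.argmax(output, dim=1) returns the same indices as the second element of torch.max(output, dim=1):

```python
import torch

num_classes = 35
batch, height, width = 2, 4, 5

# Hypothetical stand-in for the model output: random logits of shape
# [batch, num_classes, height, width]
torch.manual_seed(0)
output = torch.randn(batch, num_classes, height, width)

# torch.max over the class dimension returns (values, indices);
# the indices are the predicted class of every pixel
_, pred = torch.max(output, dim=1)

# torch.argmax returns only the indices
pred_argmax = torch.argmax(output, dim=1)

print(pred.shape)  # torch.Size([2, 4, 5])
print(pred.dtype)  # torch.int64
```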

Now I want to convert this into a 3-channel image:

```python
img = torch.zeros(output.size(0), 3, output.size(2), output.size(3))  # [batch, 3, height, width]
```

such that each pixel of the image gets the r, g, b values that the cmap assigns to its class (the value of that pixel in pred), where the cmap is given as follows:

```python
import numpy as np

# 35 rows, one (r, g, b) triple per class, values in [0, 1]
cityscapes_map = np.array([[0.        , 0.        , 0.        ],
       [0.        , 0.        , 0.        ],
       [0.        , 0.        , 0.        ],
       [0.        , 0.        , 0.        ],
       [0.07843137, 0.07843137, 0.07843137],
       [0.43529412, 0.29019608, 0.        ],
       [0.31764706, 0.        , 0.31764706],
       [0.50196078, 0.25098039, 0.50196078],
       [0.95686275, 0.1372549 , 0.90980392],
       [0.98039216, 0.66666667, 0.62745098],
       [0.90196078, 0.58823529, 0.54901961],
       [0.2745098 , 0.2745098 , 0.2745098 ],
       [0.4       , 0.4       , 0.61176471],
       [0.74509804, 0.6       , 0.6       ],
       [0.70588235, 0.64705882, 0.70588235],
       [0.58823529, 0.39215686, 0.39215686],
       [0.58823529, 0.47058824, 0.35294118],
       [0.6       , 0.6       , 0.6       ],
       [0.6       , 0.6       , 0.6       ],
       [0.98039216, 0.66666667, 0.11764706],
       [0.8627451 , 0.8627451 , 0.        ],
       [0.41960784, 0.55686275, 0.1372549 ],
       [0.59607843, 0.98431373, 0.59607843],
       [0.2745098 , 0.50980392, 0.70588235],
       [0.8627451 , 0.07843137, 0.23529412],
       [1.        , 0.        , 0.        ],
       [0.        , 0.        , 0.55686275],
       [0.        , 0.        , 0.2745098 ],
       [0.        , 0.23529412, 0.39215686],
       [0.        , 0.        , 0.35294118],
       [0.        , 0.        , 0.43137255],
       [0.        , 0.31372549, 0.39215686],
       [0.        , 0.        , 0.90196078],
       [0.46666667, 0.04313725, 0.1254902 ],
       [0.        , 0.        , 0.55686275]])
```
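cityscapes_map is a float64 NumPy array, so to do the lookup in PyTorch (for example on the same device as pred) it first has to become a tensor; casting to float32 matches the dtype of typical model outputs. A minimal sketch, assuming hypothetical helper names cmap_to_tensor and cmap_to_uint8 and a 3-row stand-in for the real map. The uint8 helper is only needed if an 8-bit image is wanted instead of [0, 1] floats:

```python
import numpy as np
import torch

def cmap_to_tensor(cmap_np, device="cpu"):
    """Convert an (N, 3) NumPy colour map with values in [0, 1] to a float32 tensor."""
    cmap = torch.from_numpy(np.asarray(cmap_np, dtype=np.float32))
    return cmap.to(device)

def cmap_to_uint8(cmap):
    """Scale a [0, 1] float colour map to 0..255, rounding to the nearest integer."""
    return (cmap * 255).round().to(torch.uint8)

# Hypothetical 3-class stand-in; with the real data, pass cityscapes_map instead
demo_map = np.array([[0.0, 0.0, 0.0],
                     [1.0, 0.0, 0.0],
                     [0.2745098, 0.50980392, 0.70588235]])

cmap = cmap_to_tensor(demo_map)
print(cmap.shape, cmap.dtype)   # torch.Size([3, 3]) torch.float32
print(cmap_to_uint8(cmap)[2])   # values 70, 130, 180
```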

Something like this, in pseudocode:

```python
img[b, :, y, x] = cmap[pred[b, y, x]]  # for every batch index b and pixel (y, x)
```
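The per-pixel lookup can be done without any Python loop: indexing a (num_classes, 3) colour tensor with the (batch, height, width) index tensor pred gathers one colour row per pixel and yields a (batch, height, width, 3) tensor, which is then permuted to channels-first. A minimal sketch, with hypothetical function names and a 4-class demo map; the loop version is included only as a slow reference to check the vectorized one against:

```python
import torch

def index_to_rgb(pred, cmap):
    """Map class indices to colours with advanced indexing.

    pred: int64 tensor [batch, height, width], values in 0..num_classes-1
    cmap: float tensor [num_classes, 3]
    returns: float tensor [batch, 3, height, width]
    """
    cmap = cmap.to(pred.device)
    rgb = cmap[pred]                             # one cmap row per pixel: [batch, height, width, 3]
    return rgb.permute(0, 3, 1, 2).contiguous()  # channels first: [batch, 3, height, width]

def index_to_rgb_loop(pred, cmap):
    """Slow reference: a literal translation of the per-pixel pseudocode."""
    b, h, w = pred.shape
    img = torch.zeros(b, 3, h, w, dtype=cmap.dtype)
    for i in range(b):
        for y in range(h):
            for x in range(w):
                img[i, :, y, x] = cmap[pred[i, y, x]]
    return img

# Hypothetical demo data: a 4-class colour map and a random prediction
torch.manual_seed(0)
demo_cmap = torch.tensor([[0.0, 0.0, 0.0],
                          [1.0, 0.0, 0.0],
                          [0.0, 1.0, 0.0],
                          [0.0, 0.0, 1.0]])
pred = torch.randint(0, 4, (2, 3, 5))

img = index_to_rgb(pred, demo_cmap)
print(img.shape)  # torch.Size([2, 3, 3, 5])
```

With the real data this would be called as index_to_rgb(pred, torch.from_numpy(cityscapes_map).float()). Because the cmap values are already in [0, 1], the resulting [batch, 3, height, width] float tensor can be passed directly to something like torchvision.utils.save_image for viewing.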

Please advise on the correct syntax.

Discussion continued here: How to visualize segmentation output - multiclass feature map to rgb image?