Converting a segmentation class tensor to an RGB image

I have the softmax output of a segmentation network which, after argmax, has size (batch_size, width, height) and contains the values 0, 1, 2, 3, i.e. the class assigned to each pixel of the image. Now I'm trying to convert this tensor into an RGB image where each class is assigned a different color. I came up with this code:

import torch

def convert_tensor_to_RGB(network_output):
    # Output tensor: (batch_size, 3, width, height)
    converted_tensor = torch.zeros([network_output.size(0), 3, network_output.size(1), network_output.size(2)])
    # Color lookup table: class index -> RGB value
    x = {
        0: torch.Tensor([0, 0, 0]),
        1: torch.Tensor([255, 0, 0]),
        2: torch.Tensor([0, 0, 255]),
        3: torch.Tensor([0, 255, 0]),
    }
    for i in range(network_output.size(0)):
        sample = network_output[i, :, :]
        # Paint every pixel of class j with its color
        for j in range(len(x)):
            converted_tensor[i, sample == j] = x[j]
    return converted_tensor

but I get an error:

IndexError: The shape of the mask [540, 960] at index 0 does not match the shape of the indexed tensor [3, 540, 960] at index 0

So is there an efficient and fast way to do this conversion?

On the line
converted_tensor[i, sample == j] = x[j]
you just forgot to index the channel dimension:
converted_tensor[i, :, sample == j] = x[j]
but then converted_tensor[i, :, mask] would be shaped [3, N] (3 channels by N selected pixels) whereas x[j] is shaped [3], so you would probably need to unsqueeze it to [3, 1] so that it broadcasts over the pixels:

converted_tensor[i, :, sample == j] = x[j].unsqueeze(1)
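
For reference, here is a minimal sketch of the complete loop version with that fix applied (the function name, color table, and class count are taken from the question above):

import torch

def convert_tensor_to_RGB(network_output):
    # network_output: (batch_size, W, H) tensor of class indices in {0, 1, 2, 3}
    converted_tensor = torch.zeros(network_output.size(0), 3, network_output.size(1), network_output.size(2))
    colors = {
        0: torch.Tensor([0, 0, 0]),
        1: torch.Tensor([255, 0, 0]),
        2: torch.Tensor([0, 0, 255]),
        3: torch.Tensor([0, 255, 0]),
    }
    for i in range(network_output.size(0)):
        sample = network_output[i]
        for j, color in colors.items():
            # ':' keeps the channel dim; the [3, 1] color broadcasts
            # over all pixels selected by the mask
            converted_tensor[i, :, sample == j] = color.unsqueeze(1)
    return converted_tensor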

EDIT:
A much better way to do this would be to use torch.nn.functional.embedding().
Put your color table in a [4, 3] float tensor:

x = torch.FloatTensor([[0, 0, 0], [255, 0, 0], [0, 0, 255], [0, 255, 0]])

Then you can get your result directly with:

converted_tensor = torch.nn.functional.embedding(network_output, x).permute(0, 3, 1, 2)

No need for a loop.
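
As a quick sanity check, here is a self-contained usage sketch. The batch size and spatial dimensions below are just the ones from your error message; note that embedding() expects a LongTensor of indices, which argmax already returns:

import torch
import torch.nn.functional as F

# Color table: one RGB row per class
x = torch.FloatTensor([[0, 0, 0], [255, 0, 0], [0, 0, 255], [0, 255, 0]])

# Fake argmax output: batch of 2, 540x960 pixels, class indices in {0, 1, 2, 3}
network_output = torch.randint(0, 4, (2, 540, 960))

# embedding maps each index to its color row -> (2, 540, 960, 3),
# then permute moves channels first -> (2, 3, 540, 960)
converted_tensor = F.embedding(network_output, x).permute(0, 3, 1, 2)
print(converted_tensor.shape)  # torch.Size([2, 3, 540, 960])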

Thank you, it works great.