I have the softmax output of a segmentation network which, after argmax, has size (batch_size, Width, Height) and contains the values 0, 1, 2 and 3: the class assigned to each pixel of the image. Now I'm trying to convert this tensor into an RGB image where each class is assigned a different color. I came up with this code:

```
import torch

def convert_tensor_to_RGB(network_output):
    # network_output: (batch_size, H, W) tensor of class indices 0..3
    converted_tensor = torch.zeros([network_output.size(0), 3, network_output.size(1), network_output.size(2)])
    # Color assigned to each class
    x = {
        0: torch.Tensor([0, 0, 0]),
        1: torch.Tensor([255, 0, 0]),
        2: torch.Tensor([0, 0, 255]),
        3: torch.Tensor([0, 255, 0]),
    }
    for i in range(network_output.size(0)):
        sample = network_output[i, :, :]
        for j in range(len(x)):
            # This assignment raises the IndexError quoted below
            converted_tensor[i, sample == j] = x[j]
    return converted_tensor
```

but I get the following error: `IndexError: The shape of the mask [540, 960] at index 0 does not match the shape of the indexed tensor [3, 540, 960] at index 0`
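The mask `sample == j` has shape (540, 960), but `converted_tensor[i]` has shape (3, 540, 960), so PyTorch tries to line the mask up with the channel dimension first and the shapes disagree. One possible minimal fix, sketched below while keeping the original loop, is to select all channels with `:` before applying the mask and to reshape each color to (3, 1) so it broadcasts over the N selected pixels. The helper name and the use of `torch.tensor` instead of `torch.Tensor` are editorial choices, not part of the original code.

```python
import torch

def convert_tensor_to_RGB(network_output):
    # network_output: (batch_size, H, W) tensor of class indices 0..3
    converted_tensor = torch.zeros(
        [network_output.size(0), 3, network_output.size(1), network_output.size(2)]
    )
    colors = {
        0: torch.tensor([0.0, 0.0, 0.0]),
        1: torch.tensor([255.0, 0.0, 0.0]),
        2: torch.tensor([0.0, 0.0, 255.0]),
        3: torch.tensor([0.0, 255.0, 0.0]),
    }
    for i in range(network_output.size(0)):
        sample = network_output[i]   # (H, W) class map
        out = converted_tensor[i]    # (3, H, W) view, so writes land in converted_tensor
        for j, color in colors.items():
            # out[:, mask] has shape (3, N); a (3, 1) color broadcasts across the N pixels
            out[:, sample == j] = color.unsqueeze(1)
    return converted_tensor
```

This still loops over the batch and the classes in Python, so it works but is not the fastest option for large batches.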

So is there an efficient, fast way to do this conversion?
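A loop-free alternative is to store the colors in a (num_classes, 3) lookup tensor and index it with the class map directly. Integer (advanced) indexing `palette[network_output]` returns one palette row per pixel, giving a (batch_size, H, W, 3) tensor, and `permute` moves the color axis into the channels-first (batch_size, 3, H, W) layout used in the question. This is a sketch under a few assumptions: the function name `class_map_to_rgb` is hypothetical, the `uint8` dtype is chosen because 0 to 255 colors are usually stored as bytes, and the usage line assumes the softmax output is laid out as (batch_size, num_classes, H, W).

```python
import torch

def class_map_to_rgb(network_output: torch.Tensor) -> torch.Tensor:
    """Map a (batch_size, H, W) tensor of class indices to a (batch_size, 3, H, W) RGB tensor."""
    # Row k holds the RGB color for class k (same colors as in the question).
    palette = torch.tensor(
        [
            [0, 0, 0],      # class 0: black
            [255, 0, 0],    # class 1: red
            [0, 0, 255],    # class 2: blue
            [0, 255, 0],    # class 3: green
        ],
        dtype=torch.uint8,
        device=network_output.device,
    )
    # Advanced indexing with a long tensor looks up one palette row per pixel: (B, H, W, 3)
    rgb = palette[network_output.long()]
    # Move the color axis to the channel position: (B, 3, H, W)
    return rgb.permute(0, 3, 1, 2).contiguous()

# Usage: class indices straight from argmax over the class dimension
# (assumes a (batch_size, num_classes, H, W) softmax output).
probs = torch.softmax(torch.randn(2, 4, 540, 960), dim=1)
rgb = class_map_to_rgb(probs.argmax(dim=1))
print(rgb.shape)  # torch.Size([2, 3, 540, 960])
```

Because the lookup is a single indexing operation, it runs on whatever device the class map lives on and avoids the per-class Python loop. If a float tensor like the one produced by the original `torch.zeros` is needed, `.float()` can be applied to the result.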