Ah, curr_color should be used instead:
class_to_color = [torch.tensor([1.0, 0.0, 0.0]), ...]
output = torch.zeros(1, 3, out.size(-2), out.size(-1), dtype=torch.float)
for class_idx, color in enumerate(class_to_color):
    mask = out[:, class_idx, :, :] == torch.max(out, dim=1)[0]
    mask = mask.unsqueeze(1)  # should have shape 1, 1, 100, 100
    curr_color = color.reshape(1, 3, 1, 1)
    segment = mask * curr_color  # should have shape 1, 3, 100, 100
    output += segment
However, note that you should make sure the color formatting is consistent (e.g., either floating point values between 0.0 and 1.0 or integers between 0 and 255).
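
For reference, here is a minimal self-contained sketch of the same idea that can be run as-is. The three-color palette, the `colorize` helper name, and the random `out` tensor are assumptions for illustration only, not part of your model. It also swaps the `== torch.max(...)` comparison for `argmax`, so that a pixel with tied scores gets exactly one color instead of a sum of colors, and it shows one way to convert between the two color formats mentioned above.

```python
import torch

# Hypothetical 3-class palette, floats in [0.0, 1.0] (assumed for illustration).
class_to_color = [
    torch.tensor([1.0, 0.0, 0.0]),  # class 0 -> red
    torch.tensor([0.0, 1.0, 0.0]),  # class 1 -> green
    torch.tensor([0.0, 0.0, 1.0]),  # class 2 -> blue
]

def colorize(out, class_to_color):
    """Map scores of shape [N, C, H, W] to an RGB image of shape [N, 3, H, W]."""
    pred = out.argmax(dim=1)  # [N, H, W], exactly one class index per pixel
    output = torch.zeros(out.size(0), 3, out.size(-2), out.size(-1),
                         dtype=torch.float, device=out.device)
    for class_idx, color in enumerate(class_to_color):
        mask = (pred == class_idx).unsqueeze(1)  # [N, 1, H, W]
        curr_color = color.to(out.device).reshape(1, 3, 1, 1)
        output += mask * curr_color  # broadcasts to [N, 3, H, W]
    return output

torch.manual_seed(0)
out = torch.randn(1, 3, 100, 100)  # stand-in for the model output
rgb_float = colorize(out, class_to_color)  # floats in [0.0, 1.0]
rgb_uint8 = (rgb_float * 255).round().to(torch.uint8)  # integers in [0, 255]
```

Whichever format you pick, keep the palette and the final image in the same one; mixing a 0-255 palette with code that expects 0.0-1.0 is a common source of washed-out or all-white outputs.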