How to save multi-class segmentation prediction as image?

Ah, `curr_color` should be used instead of `color`, so that each color is reshaped to broadcast against the mask:

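    import torch

    # one RGB color per class (remaining entries elided)
    class_to_color = [torch.tensor([1.0, 0.0, 0.0]), ...]
    # out: network output of shape (1, num_classes, H, W)
    pred = out.argmax(dim=1)  # (1, H, W): exactly one class per pixel, so ties never mix colors
    output = torch.zeros(1, 3, out.size(-2), out.size(-1), dtype=torch.float, device=out.device)
    for class_idx, color in enumerate(class_to_color):
        mask = (pred == class_idx).unsqueeze(1)  # shape (1, 1, H, W)
        curr_color = color.reshape(1, 3, 1, 1).to(out.device)  # shape (1, 3, 1, 1)
        segment = mask * curr_color  # broadcasts to shape (1, 3, H, W)
        output += segment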
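If you prefer to avoid the Python loop, the same coloring can be done with one indexing step. This is a minimal sketch, assuming `out` holds logits of shape `(1, num_classes, H, W)` and that the class colors are stacked into a `(num_classes, 3)` float tensor; `colorize` and `palette` are hypothetical names used only for illustration.

```python
import torch


def colorize(out: torch.Tensor, palette: torch.Tensor) -> torch.Tensor:
    """Map logits of shape (1, C, H, W) to an RGB tensor of shape (1, 3, H, W).

    `palette` is a hypothetical (C, 3) float tensor with one RGB color per class.
    """
    pred = out.argmax(dim=1)                   # (1, H, W): class index per pixel
    rgb = palette.to(out.device)[pred]         # advanced indexing -> (1, H, W, 3)
    return rgb.permute(0, 3, 1, 2)             # channels first -> (1, 3, H, W)


# Example: 2 classes on a 2x2 image
palette = torch.tensor([[1.0, 0.0, 0.0],    # class 0 -> red
                        [0.0, 0.0, 1.0]])   # class 1 -> blue
out = torch.tensor([[[[5.0, 0.0], [0.0, 5.0]],     # class 0 scores
                     [[0.0, 5.0], [5.0, 0.0]]]])   # class 1 scores
image = colorize(out, palette)
print(image.shape)        # torch.Size([1, 3, 2, 2])
print(image[0, :, 0, 0])  # tensor([1., 0., 0.])
```

Indexing the palette with the predicted class map gives each pixel exactly one color, which matches what the loop computes with an `argmax` mask.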

However, make sure the color format is consistent across all classes and matches whatever you use to save the image: either floating-point values between 0.0 and 1.0, or integers between 0 and 255.
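To actually write the result to disk, one option is to convert the float tensor to 8-bit integers and pass it to Pillow. This is a minimal sketch, assuming Pillow is installed and the colors are floats in [0.0, 1.0]; `to_uint8_image` is a hypothetical helper name. If torchvision is available, `torchvision.utils.save_image` also accepts a float tensor in [0.0, 1.0] directly.

```python
import torch
from PIL import Image


def to_uint8_image(rgb: torch.Tensor) -> Image.Image:
    """Convert a (1, 3, H, W) float tensor with values in [0.0, 1.0] to a PIL RGB image."""
    arr = rgb[0].permute(1, 2, 0)  # (H, W, 3): the channels-last layout PIL expects
    # scale floats in [0.0, 1.0] to integers in [0, 255]
    arr = (arr.clamp(0.0, 1.0) * 255).round().to(torch.uint8)
    return Image.fromarray(arr.cpu().numpy())  # uint8 (H, W, 3) is read as RGB


# Example: a single magenta pixel
img = to_uint8_image(torch.tensor([[[[1.0]], [[0.0]], [[1.0]]]]))
print(img.getpixel((0, 0)))  # (255, 0, 255)
```

Usage with the loop above would be `to_uint8_image(output).save("prediction.png")`, where the file name is only an example.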