How to save a multi-class segmentation prediction as an image?

There may be a utility function somewhere that does this, but you can also write your own, for example:

    import torch

    # `out` is the model output with shape (1, num_classes, H, W)
    class_to_color = [torch.tensor([1.0, 0.0, 0.0]), ...]  # one RGB color per class
    output = torch.zeros(1, 3, out.size(-2), out.size(-1), dtype=torch.float)
    labels = out.argmax(dim=1)  # (1, H, W): highest-scoring class per pixel
    for class_idx, color in enumerate(class_to_color):
        mask = labels == class_idx  # (1, H, W) boolean mask for this class
        mask = mask.unsqueeze(1)  # (1, 1, H, W)
        curr_color = color.reshape(1, 3, 1, 1)
        segment = mask * curr_color  # broadcasts to (1, 3, H, W)
        output += segment
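If you prefer to avoid the Python loop, here is a minimal vectorized sketch. It assumes `out` has shape `(N, num_classes, H, W)` and that the class colors are stacked into a `(num_classes, 3)` float tensor called `palette`; the helper names `colorize` and `save_png` are hypothetical, and saving assumes Pillow is installed.

```python
import torch


def colorize(out: torch.Tensor, palette: torch.Tensor) -> torch.Tensor:
    """Map scores (N, C, H, W) to an RGB tensor (N, 3, H, W) with values in [0, 1].

    `palette` is a (C, 3) float tensor holding one RGB color per class.
    """
    labels = out.argmax(dim=1)       # (N, H, W): winning class index per pixel
    rgb = palette[labels]            # (N, H, W, 3): look up each pixel's class color
    return rgb.permute(0, 3, 1, 2)   # (N, 3, H, W), channels-first


def save_png(image: torch.Tensor, path: str) -> None:
    """Save a single (3, H, W) image with values in [0, 1] as an 8-bit PNG."""
    from PIL import Image  # assumed dependency, imported only when saving

    array = (image.clamp(0, 1) * 255).round().to(torch.uint8)
    array = array.permute(1, 2, 0).contiguous().numpy()  # (H, W, 3) for Pillow
    Image.fromarray(array).save(path)


if __name__ == "__main__":
    palette = torch.tensor([
        [0.0, 0.0, 0.0],  # class 0: black
        [1.0, 0.0, 0.0],  # class 1: red
        [0.0, 1.0, 0.0],  # class 2: green
    ])
    out = torch.randn(1, 3, 100, 100)  # stand-in for real model output
    image = colorize(out, palette)
    print(image.shape)  # torch.Size([1, 3, 100, 100])
```

To write the first image of the batch to disk you would call `save_png(image[0], "prediction.png")`. Indexing the palette with the argmax labels replaces the per-class loop with a single lookup, and since each pixel gets exactly one class, two colors are never summed at the same pixel.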