There may already be a utility function that does this, but you can write your own function that paints each pixel with the colour of its highest-scoring class, like:
import torch


def class_scores_to_rgb(out, class_to_color):
    """Render per-pixel class scores as an RGB image.

    Each pixel is painted with the colour of the class that has the highest
    score at that pixel (``argmax`` over the class dimension).

    Args:
        out: Score tensor of shape ``(N, C, H, W)``, e.g. segmentation logits,
            where ``C`` is the number of classes.
        class_to_color: Sequence of RGB colours. Entry ``i`` is a 3-element
            tensor giving the colour of class ``i``. Pixels whose winning
            class has no entry in this sequence stay black.

    Returns:
        Float32 tensor of shape ``(N, 3, H, W)`` on the same device as ``out``.
    """
    # argmax picks exactly one winning class per pixel. Comparing against
    # torch.max(out, dim=1) does not work: that call returns a
    # (values, indices) namedtuple, and even with `.values` tied pixels would
    # be painted with the sum of several colours.
    labels = out.argmax(dim=1, keepdim=True)  # (N, 1, H, W)
    output = torch.zeros(
        out.size(0), 3, out.size(-2), out.size(-1),
        dtype=torch.float, device=out.device,
    )
    for class_idx, color in enumerate(class_to_color):
        mask = labels == class_idx  # (N, 1, H, W), bool
        # Reshape to (1, 3, 1, 1) so mask * colour broadcasts to (N, 3, H, W).
        curr_color = color.to(output).reshape(1, 3, 1, 1)
        output += mask * curr_color
    return output


# Example usage (give one RGB colour per class; `out` is the model output):
# class_to_color = [torch.tensor([1.0, 0.0, 0.0]), torch.tensor([0.0, 1.0, 0.0])]
# output = class_scores_to_rgb(out, class_to_color)