Yes, the shape is correct. Here is the code:
pred = torch.sigmoid(model(x))
out = (pred > 0.5).float()
print(f"out shape: {out.shape}\n")
class_to_color = [
    torch.tensor([0.0, 0.0, 0.0]),
    torch.tensor([10.0, 133.0, 1.0]),
    torch.tensor([14.0, 1.0, 133.0]),
    torch.tensor([33.0, 255.0, 1.0]),
    torch.tensor([243.0, 5.0, 247.0]),
    torch.tensor([255.0, 0.0, 0.0]),  # was torch.tensor([(255, 0, 0)]), which has shape (1, 3)
]
output = torch.zeros(1, 3, out.size(-2), out.size(-1), dtype=torch.float)
for class_idx, color in enumerate(class_to_color):
    mask = out[:, class_idx, :, :] == torch.max(out, dim=1)[0]
    print(f"{mask}\n")
    mask = mask.unsqueeze(1)  # should have shape 1, 1, 100, 100
    print(f"mask shape: {mask.shape}\n")
    curr_color = color.reshape(1, 3, 1, 1)
    print(f"color shape: {color.shape}\n")
    segment = mask*color  # should have shape 1, 3, 100, 100
    output += segment
torchvision.utils.save_image(output, f"{folder}/pred_{idx}.png")
Here is the corresponding output:
out shape: torch.Size([1, 6, 100, 100])
tensor([[[True, True, True, ..., True, True, True],
[True, True, True, ..., True, True, True],
[True, True, True, ..., True, True, True],
...,
[True, True, True, ..., True, True, True],
[True, True, True, ..., True, True, True],
[True, True, True, ..., True, True, True]]], device='cuda:0')
mask shape: torch.Size([1, 1, 100, 100])
color shape: torch.Size([3])
And the error:
segment = mask*color # should have shape 1, 3, 100, 100
RuntimeError: The size of tensor a (100) must match the size of tensor b (3) at non-singleton dimension 3
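A standalone sketch of the broadcast with dummy tensors (not the real model output): multiplying the `(1, 1, 100, 100)` mask by the flat `(3,)` color reproduces the error, since broadcasting aligns trailing dimensions (100 vs. 3). Multiplying by the reshaped `(1, 3, 1, 1)` tensor, which the loop already computes as `curr_color` but never uses, broadcasts to the expected shape:

```python
import torch

# Dummy stand-ins with the shapes from the post
mask = torch.ones(1, 1, 100, 100)           # mask after unsqueeze(1)
color = torch.tensor([10.0, 133.0, 1.0])    # shape (3,)

# (1, 1, 100, 100) * (3,): trailing dims 100 vs. 3 cannot broadcast
try:
    segment = mask * color
except RuntimeError as e:
    print(e)  # size mismatch at non-singleton dimension 3

# (1, 1, 100, 100) * (1, 3, 1, 1): the color expands over H and W
curr_color = color.reshape(1, 3, 1, 1)
segment = mask * curr_color
print(segment.shape)  # torch.Size([1, 3, 100, 100])
```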