I followed the approach in your suggested post, and now it works well for me. Thanks for raising the issue.
# Run the model over `loader` and save each batch's prediction as a color-coded
# segmentation image (one RGB color per predicted class) to `folder`.
#
# Palette indexed by class id: one RGB row per model output channel, so the
# number of rows must equal out_channels. Values are written on a 0-255 scale
# and divided by 255 because torchvision.utils.save_image expects float images
# in [0, 1] (it multiplies by 255 and clamps). Built once, on the same device
# as the predictions, so the lookup below does not mix CPU and GPU tensors.
class_to_color = torch.tensor(
    [
        [0, 0, 0],
        [14, 1, 133],
        [33, 255, 1],
        [243, 5, 247],
        [255, 0, 0],
    ],
    dtype=torch.float,
    device=device,
) / 255.0

# NOTE(review): this loop assumes `model` is already in eval mode
# (dropout/batch-norm frozen); confirm the caller sets model.eval().
for idx, (x, y) in enumerate(loader):
    x = x.to(device=device)
    with torch.no_grad():
        preds = torch.sigmoid(model(x))
    if preds.size(1) != class_to_color.size(0):
        raise ValueError(
            f"model produced {preds.size(1)} channels but the palette has "
            f"{class_to_color.size(0)} colors"
        )
    # Give every pixel exactly one class: the channel with the highest score.
    # (Thresholding each channel at 0.5 independently lets several classes -
    # or none - claim a pixel, which summed their colors together.)
    class_map = preds.argmax(dim=1)  # (B, H, W), integer class ids
    # Palette lookup yields (B, H, W, 3); move channels first for save_image.
    output = class_to_color[class_map].permute(0, 3, 1, 2)  # (B, 3, H, W)
    print('saved pic shape {}, origin shape {}'.format(output.shape, y.shape))
    torchvision.utils.save_image(output, f"{folder}/pred_{idx}.jpg")