Error visualizing my U-Net output

@ptrblck @Neda I followed your discussion to set up my multiclass segmentation problem. The output of my network has the same spatial size as my input: [8, 4, 572, 572], where 8 is the batch size and 4 is the number of classes in my images, including the background. I tried the following snippet to visualize the network output:
import torch.nn.functional as F
import torchvision

output = model(images)
print(output.size())  # torch.Size([8, 4, 572, 572])
prob = F.softmax(output, dim=1)  # softmax over the class dimension
print(prob.size())  # torch.Size([8, 4, 572, 572]), same shape as output
# The make_grid call below raises:
# RuntimeError: The size of tensor a (3) must match the size of tensor b (8) at non-singleton dimension 0
prob_imgs = torchvision.utils.make_grid(output.permute(1, 0, 2, 3))
print(prob_imgs.size())
I don't know where I went wrong. Could someone help me out?

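Update:
import torch
import matplotlib.pyplot as plt

output = model(images)
print(output.size())  # torch.Size([8, 4, 572, 572])
# argmax over dim=1 (the class dimension) gives one class index per pixel: [8, 572, 572]
pred = torch.argmax(output, dim=1).detach().cpu()
plt.imshow(pred[0])  # class map of the first image in the batch
plt.show()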

This snippet solved the issue.
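A class-index map is often easier to read in color. A minimal sketch, assuming a hypothetical 4-entry palette (background black, the other three classes red, green and blue) and random logits in place of model(images): take the argmax over the class dimension (dim=1), then index the palette with it to get one RGB mask per image.

```python
import torch

# Hypothetical stand-in for model(images), using the shape from the post.
output = torch.randn(8, 4, 572, 572)

# Class index per pixel: argmax over the class dimension -> [8, 572, 572]
pred = torch.argmax(output, dim=1)

# Hypothetical palette: one RGB color per class (row 0 = background).
palette = torch.tensor(
    [[0, 0, 0], [255, 0, 0], [0, 255, 0], [0, 0, 255]], dtype=torch.uint8
)

# Advanced indexing maps every class index to its color -> [8, 572, 572, 3]
color_masks = palette[pred]
print(color_masks.shape)
```

Each mask can then be shown with plt.imshow(color_masks[0].numpy()), since matplotlib accepts (H, W, 3) uint8 arrays directly.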