Getting proper class predictions in multi-class segmentation

This is for channel 0:

import matplotlib.pyplot as plt
c = plt.imshow(pred[0, 0, :, 20, :].detach().cpu().numpy())

This is for channel 5:

c = plt.imshow(pred[0, 5, :, 20, :].detach().cpu().numpy())

This is for channel 1:

c = plt.imshow(pred[0, 1, :, 20, :].detach().cpu().numpy())
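Since the goal is one class per voxel rather than separate per-channel maps, a common approach is to take the argmax over the channel dimension. Note also that each `plt.imshow` call rescales colors to that slice's own min and max, so raw channel plots are not directly comparable. A softmax over the channel axis puts every channel on a shared [0, 1] probability scale. The following is a minimal sketch, not necessarily what this pipeline intends. It assumes `pred` holds raw logits shaped (batch, channels, dim2, dim3, dim4) and operates on the NumPy array from `pred.detach().cpu().numpy()`. The helper names `channel_softmax` and `class_map`, the 6-class count, and the random stand-in data are hypothetical.

```python
import numpy as np


def channel_softmax(logits: np.ndarray, axis: int = 1) -> np.ndarray:
    """Softmax over the channel axis: raw scores -> per-class probabilities."""
    # Subtract the per-voxel max before exponentiating for numerical stability.
    shifted = logits - logits.max(axis=axis, keepdims=True)
    exp = np.exp(shifted)
    return exp / exp.sum(axis=axis, keepdims=True)


def class_map(logits: np.ndarray, axis: int = 1) -> np.ndarray:
    """Predicted class index per voxel (argmax over the channel axis)."""
    # Softmax is monotonic, so argmax over logits equals argmax over probabilities.
    return np.argmax(logits, axis=axis)


if __name__ == "__main__":
    # Hypothetical stand-in for pred.detach().cpu().numpy():
    # batch=1, 6 classes (channel 5 exists in the question), 8 x 32 x 8 volume.
    rng = np.random.default_rng(0)
    pred_np = rng.normal(size=(1, 6, 8, 32, 8))

    probs = channel_softmax(pred_np)  # same shape as pred_np, sums to 1 over axis 1
    labels = class_map(pred_np)       # shape (1, 8, 32, 8): channel axis removed

    # Same slice as the plots above, minus the channel index.
    label_slice = labels[0, :, 20, :]
    print(label_slice.shape)  # (8, 8)
```

In PyTorch the same steps are `pred.argmax(dim=1)` and `torch.softmax(pred, dim=1)`. When plotting the label slice, passing fixed `vmin`/`vmax` (for example `0` and the number of classes minus one) keeps each class mapped to the same color across slices.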