Getting a multi-class segmentation mask from a multi-channel model prediction

I have a 9-channel (one channel per class) 3D prediction of size b, C, 48, 64, 64, where b is the batch size and C = 9. The predicted values lie in the range 0~1.

In the original segmentation mask, the class labels are: labels = [0, 1, 2, 3, 4, 8, 10, 11, 56]
Is there a convenient way to reconstruct the final output with shape b, 48, 64, 64, with every voxel assigned its original label value? In other words, something like the inverse of one-hot encoding.
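A minimal sketch of that inverse step, written with NumPy and assuming that channel i of the prediction corresponds to labels[i] (i.e. the channels are in the same order as the labels list used when the mask was one-hot encoded). It takes the argmax over the channel axis to get an index 0..8 per voxel, then uses that index to look up the original label value. Because argmax is unaffected by monotonic transforms, it should not matter whether the 0~1 values are sigmoid or softmax outputs. The function name probs_to_label_mask is hypothetical.

```python
import numpy as np

# Class labels from the original mask. ASSUMPTION: channel i of the
# prediction corresponds to LABELS[i] (same order as the one-hot encoding).
LABELS = np.array([0, 1, 2, 3, 4, 8, 10, 11, 56], dtype=np.uint8)


def probs_to_label_mask(probs: np.ndarray, labels: np.ndarray = LABELS) -> np.ndarray:
    """Convert a (b, C, D, H, W) score array into a (b, D, H, W) label mask."""
    if probs.shape[1] != len(labels):
        raise ValueError(f"expected {len(labels)} channels, got {probs.shape[1]}")
    # Index of the highest-scoring channel per voxel: values in 0..C-1.
    channel_idx = probs.argmax(axis=1)
    # Fancy indexing maps each channel index to its original label value.
    return labels[channel_idx]


if __name__ == "__main__":
    rng = np.random.default_rng(0)
    probs = rng.random((2, 9, 48, 64, 64), dtype=np.float32)
    mask = probs_to_label_mask(probs)
    print(mask.shape, mask.dtype)  # (2, 48, 64, 64) uint8
```

The same idea should carry over to a PyTorch tensor directly, e.g. `torch.tensor(labels, device=pred.device)[pred.argmax(dim=1)]`, which avoids the round trip through NumPy until the end.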

In addition, I have the following code:

pred = model(img)
pred = pred.data.max(1)[1].cpu().numpy()[:, :, :].astype('uint8')

Does it help in any way?
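For reference, `pred.data.max(1)` returns a (values, indices) pair taken over dimension 1, so `[1]` selects the argmax indices: an array of shape b, 48, 64, 64 holding channel numbers 0..8, not the label values (the `[:, :, :]` slice is a no-op). That line therefore does most of the work; what remains is mapping each channel index back to its label. The sketch below assumes the same channel order as the labels list and uses a hypothetical helper name, indices_to_labels; casting to uint8 is safe here because both the indices (0..8) and the largest label (56) fit in that type.

```python
import numpy as np

# ASSUMPTION: channel i corresponds to labels[i] (same order as the one-hot encoding).
labels = np.array([0, 1, 2, 3, 4, 8, 10, 11, 56], dtype=np.uint8)


def indices_to_labels(channel_idx: np.ndarray) -> np.ndarray:
    """Map argmax channel indices (0..len(labels)-1) to original label values."""
    # np.take looks up labels[channel_idx] element-wise and keeps the input shape.
    return np.take(labels, channel_idx)


# Toy array standing in for pred.data.max(1)[1].cpu().numpy().astype('uint8')
channel_idx = np.array([[[[0, 5], [7, 8]]]], dtype=np.uint8)  # shape (1, 1, 2, 2)
print(indices_to_labels(channel_idx).ravel().tolist())  # [0, 8, 11, 56]
```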