I think I found a decent solution to this. If anyone has another approach, please share!
from matplotlib import cm
...
# Overlay a viridis-colormapped prediction volume on top of the image volume.
x = torch.from_numpy(img_vol)
y = torch.from_numpy(pred_vol)

# Map every prediction value to an RGBA colour. Calling the colormap on the
# whole array appends a trailing RGBA axis of length 4; moving that axis to
# position 1 yields the same layout as
# np.apply_along_axis(cm.viridis, 0, ...) without invoking the colormap once
# per 1-D slice in a Python-level loop.
cmap_vol = np.moveaxis(cm.viridis(y.numpy()), -1, 1)
# Drop singleton axes (presumably the single prediction channel -- confirm
# against pred_vol's shape) before converting back to a tensor.
cmap_vol = torch.from_numpy(np.squeeze(cmap_vol))

# Pass the colormapped volume, not the raw prediction `y`, as the overlay;
# otherwise cmap_vol is dead code and the colormap is never shown.
# NOTE(review): assumes show(image, overlay, alpha=...) blends its second
# argument over the first -- confirm against show()'s definition.
show(make_grid(x), make_grid(cmap_vol), alpha=0.5)
plt.show()