torchvision.utils.make_grid with colormaps?

I think I found a decent solution: map the predictions through a matplotlib colormap before passing them to make_grid. If anyone has another approach, please share!

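    import numpy as np
    import torch
    import matplotlib.pyplot as plt
    from matplotlib import cm
    from torchvision.utils import make_grid
    ...
    x = torch.from_numpy(img_vol)   # images, e.g. shape (N, 1, H, W)
    y = torch.from_numpy(pred_vol)  # predictions in [0, 1], same shape
    # cm.viridis maps each value to RGBA; along axis 0 this gives (N, 4, 1, H, W)
    cmap_vol = np.apply_along_axis(cm.viridis, 0, y.numpy())
    cmap_vol = torch.from_numpy(np.squeeze(cmap_vol))  # (N, 4, H, W), channels first

    # show() is a custom plotting helper, not part of torchvision;
    # pass the colormapped volume to make_grid, not the raw predictions y
    show(make_grid(x), make_grid(cmap_vol), alpha=0.5)
    plt.show()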
