This is for channel 0:
c = plt.imshow(pred.data[0, 0, :, 20, :].cpu().numpy())
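Assuming `pred` is a 5D tensor laid out as (batch, channel, axis 2, axis 3, axis 4), the index `[0, 0, :, 20, :]` selects batch 0, channel 0, and position 20 along the fourth axis. That leaves a 2D array `imshow` can draw. Here is a minimal NumPy sketch of the same indexing. The shape used is made up for illustration:

```python
import numpy as np

# Stand-in for pred.data.cpu().numpy(); the shape (1, 6, 32, 40, 48) is hypothetical.
vol = np.random.rand(1, 6, 32, 40, 48)

# batch 0, channel 0, all of axis 2, position 20 on axis 3, all of axis 4
slice_2d = vol[0, 0, :, 20, :]
print(slice_2d.shape)  # (32, 48)
```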
This is for channel 5:
c = plt.imshow(pred.data[0, 5, :, 20, :].cpu().numpy())
This is for channel 1:
c = plt.imshow(pred.data[0, 1, :, 20, :].cpu().numpy())
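By default, `plt.imshow` scales the colors of each image to that image's own minimum and maximum. Two channels with very different value ranges can therefore look alike when plotted separately. The sketch below draws several channels side by side on a shared color scale. The helper names `channel_slices`, `shared_limits`, and `plot_channels` are hypothetical. The input is assumed to be the NumPy array from `pred.data.cpu().numpy()`, indexed the same way as above:

```python
import numpy as np

def channel_slices(vol, channels, index, batch=0):
    """Return vol[batch, c, :, index, :] for each channel c (same indexing as the posts above)."""
    return [vol[batch, c, :, index, :] for c in channels]

def shared_limits(slices):
    """Common (vmin, vmax) across all slices so every panel uses one color scale."""
    return min(float(s.min()) for s in slices), max(float(s.max()) for s in slices)

def plot_channels(vol, channels, index, batch=0):
    # Imported here so the helpers above only need NumPy.
    import matplotlib.pyplot as plt

    slices = channel_slices(vol, channels, index, batch)
    vmin, vmax = shared_limits(slices)
    fig, axes = plt.subplots(1, len(channels), squeeze=False)
    for ax, c, s in zip(axes[0], channels, slices):
        im = ax.imshow(s, vmin=vmin, vmax=vmax)  # same limits on every panel
        ax.set_title(f"Channel {c}")
    fig.colorbar(im, ax=axes[0].tolist())  # one colorbar shared by all panels
    return fig
```

For example, `plot_channels(pred.detach().cpu().numpy(), [0, 1, 5], 20)` followed by `plt.show()` would put channels 0, 1, and 5 in one figure. Differences in magnitude then show up as differences in color instead of being normalized away.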