Can I train one network with different image scales?

My work is medical image segmentation. Because of the high image dimensions, I always train the segmentation network with smaller patches, but this loses the global information about the organs. So I'd like to use the whole 3D or 2D image to update the network during the last iteration of every epoch. But I'm not sure whether that's feasible and how to make it happen.

If the original image resolution yields an out-of-memory error on your device, you could try one of the following:

  • apply torch.utils.checkpoint to trade compute for memory (first sketch below)
  • use model sharding, if you have more than a single device, and execute parts of the model on separate devices (second sketch)
  • execute the last iteration on the CPU, since you would usually have more system RAM than GPU memory (third sketch)
  • try mixed precision training for potentially reduced memory usage and a speedup (you would need to install the nightly binaries for this or build from master; last sketch).
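
A minimal sketch of the checkpointing idea from the first bullet, using torch.utils.checkpoint.checkpoint_sequential on a toy nn.Sequential model (the layer sizes are made up). The activations inside each segment are recomputed during the backward pass instead of being stored:

```python
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint_sequential

# Toy fully convolutional model for illustration only
model = nn.Sequential(
    nn.Conv2d(1, 32, 3, padding=1), nn.ReLU(),
    nn.Conv2d(32, 32, 3, padding=1), nn.ReLU(),
    nn.Conv2d(32, 2, 3, padding=1),
)

# The input must require gradients for checkpointing to backpropagate
x = torch.randn(1, 1, 512, 512, requires_grad=True)

# Split the model into 2 segments; only the segment boundaries keep activations
out = checkpoint_sequential(model, 2, x)
out.mean().backward()
```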
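
For model sharding, a simple manual pipeline could look like this sketch (assuming two GPUs; the layer names and sizes are again placeholders). Note that the activations have to be moved between devices in forward, and the target tensor should live on the device of the final output:

```python
import torch
import torch.nn as nn

class ShardedNet(nn.Module):
    def __init__(self):
        super().__init__()
        # First half of the model on cuda:0, second half on cuda:1
        self.part1 = nn.Sequential(
            nn.Conv2d(1, 32, 3, padding=1), nn.ReLU()
        ).to('cuda:0')
        self.part2 = nn.Sequential(
            nn.Conv2d(32, 2, 3, padding=1)
        ).to('cuda:1')

    def forward(self, x):
        x = self.part1(x.to('cuda:0'))
        # Move the intermediate activation to the second device
        x = self.part2(x.to('cuda:1'))
        return x

model = ShardedNet()
out = model(torch.randn(1, 1, 512, 512))  # output lands on cuda:1
```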
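
The "last iteration on the CPU" suggestion could be sketched as below; model, optimizer, criterion, full_image, and target are placeholders from your own training loop. One caveat: stateful optimizers (e.g. Adam) keep their state tensors on the old device, so a stateless optimizer such as plain SGD is the simplest fit for this pattern:

```python
model.cpu()                        # parameters are moved in-place
output = model(full_image.cpu())   # full-resolution forward pass in system RAM
loss = criterion(output, target.cpu())
optimizer.zero_grad()
loss.backward()
optimizer.step()
model.cuda()                       # back to the GPU for the next epoch
```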
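
Finally, a sketch of a mixed precision training step using the native torch.cuda.amp API mentioned in the last bullet (model, optimizer, criterion, and loader are placeholders). autocast runs the forward pass in lower precision where safe, and GradScaler scales the loss to avoid underflowing gradients:

```python
import torch
from torch.cuda.amp import autocast, GradScaler

scaler = GradScaler()

for data, target in loader:
    optimizer.zero_grad()
    with autocast():
        output = model(data)
        loss = criterion(output, target)
    scaler.scale(loss).backward()  # backward on the scaled loss
    scaler.step(optimizer)         # unscales gradients, then steps
    scaler.update()
```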