3D Images too large, UNet

I have several 3D image datasets. These images are too large to fit on the GPU we currently have (RTX 4090, 24GB).

I have designed and implemented Lightning modules using MONAI’s UNet. The original implementation involved downsampling the images using MONAI’s spacing transform in order to fit on the GPU. I would prefer to train on full resolution images.
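For reference, the downsampling in that original pipeline was roughly the following (a sketch; the pixdim values are placeholders, not my actual spacing):

```python
# Rough sketch of the original downsampling step (pixdim values are placeholders).
from monai.transforms import Spacingd

downsample = Spacingd(
    keys=["image", "label"],
    pixdim=(2.0, 2.0, 2.0),        # coarser voxel spacing -> smaller volumes
    mode=("bilinear", "nearest"),  # interpolate the image, keep labels discrete
)
```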

I started looking into how to process these images. For the training step, I found that I can use crop transforms to reduce the image size and collect chunks/patches that way (RandCropByPosNegLabel seems to be the most common one). From my understanding, the validation step shouldn't use these random crop transforms.
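For context, the patch-based training transforms I have in mind look roughly like this (a sketch only; the keys, patch size, and sample count are placeholders for my setup):

```python
# Sketch of patch-based training transforms (keys and sizes are placeholders).
from monai.transforms import (
    Compose,
    LoadImaged,
    EnsureChannelFirstd,
    RandCropByPosNegLabeld,
)

train_transforms = Compose([
    LoadImaged(keys=["image", "label"]),
    EnsureChannelFirstd(keys=["image", "label"]),
    # Sample fixed-size patches around foreground/background voxels so each
    # training step only sees a crop that fits in GPU memory.
    RandCropByPosNegLabeld(
        keys=["image", "label"],
        label_key="label",
        spatial_size=(96, 96, 96),  # patch size, tuned to the GPU
        pos=1,
        neg=1,
        num_samples=4,              # patches drawn per volume
    ),
])
```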

I started to use MONAI’s sliding window inference, which handles the prediction side. However, it returns the entire image’s prediction as a single tensor. That tensor is too large for the GPU and has to be pushed to the CPU, so the loss function and the other metrics must then be computed on the CPU. This leads to exceptionally slow training times.
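My validation/prediction step currently looks roughly like this (a sketch; the roi_size, batch sizes, and loss call stand in for my actual Lightning module):

```python
# Sketch of the validation step as it stands (sizes and names are placeholders).
import torch
from monai.inferers import sliding_window_inference

def validation_step(self, batch, batch_idx):
    images, labels = batch["image"], batch["label"]
    # Patches run on the GPU, but the stitched full-resolution prediction is
    # accumulated on the CPU (device="cpu") because it doesn't fit in 24 GB.
    logits = sliding_window_inference(
        inputs=images,
        roi_size=(96, 96, 96),
        sw_batch_size=4,
        predictor=self.model,
        overlap=0.25,
        sw_device=torch.device("cuda"),
        device=torch.device("cpu"),
    )
    # The loss and metrics then have to be computed on the CPU as well.
    loss = self.loss_fn(logits, labels.cpu())
    return loss
```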

I’ve also tried MONAI’s GridPatchDataset, but it seems to create the patches between training steps, which also leads to exceptionally slow training times, even slower than the approach above.

I’m unsure if I’m just googling the wrong things, but I can’t seem to find any solutions that help. I’m already down to batch_size=1, precision=16, etc.

My question is: How are people handling such large images for UNet image segmentation?

Hi Goofy!

My recommendation would be to follow the “tiling strategy” mentioned in the
original U-Net paper. Note that for tiling to work cleanly, the U-Net needs to be
designed and used properly. (It is well worth reading that original paper with care.)

(Tiling means cutting the input image into overlapping tiles, passing them individually
through the U-Net, and then sewing the output tiles back together to construct the
entire output image. If your U-Net is properly designed, when you do this you get
the same result – up to numerics – as if you had passed the entire input image
through the U-Net in one go.)

Working with 3D naturally makes greater memory demands. U-Net has a “field of
view” – that is, the size of the region of the input image that affects a given output
pixel. You need to be able to fit a field of view in memory for tiling to work (without
making complicated and performance-degrading modifications). I would expect
that you could fit a field-of-view tile for a “typical” 3D U-Net into 24GB.
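As a back-of-the-envelope check, you can estimate the encoder’s field of view with the usual receptive-field recursion (the kernel sizes and strides below are a hypothetical four-level encoder, not your particular network):

```python
# Rough field-of-view estimate for a stack of conv / downsampling layers.
# Each entry is (kernel_size, stride); the values below are hypothetical.
def field_of_view(layers):
    rf, jump = 1, 1
    for kernel, stride in layers:
        rf += (kernel - 1) * jump   # each layer widens the receptive field
        jump *= stride              # downsampling multiplies later contributions
    return rf

# Four levels of: two 3x3x3 convs followed by a stride-2 downsampling.
encoder = [(3, 1), (3, 1), (2, 2)] * 4
print(field_of_view(encoder))  # 76 voxels per side for this hypothetical encoder
```

The decoder path widens this further, but the encoder value gives you a reasonable lower bound on the cube of context a tile needs to carry.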

I’m not aware of any pre-packaged tiling code (although something could be out
there). Tiling is in principle straightforward, but requires some detail-oriented
coding to get right.
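That said, in rough outline the procedure looks something like the sketch below. (This is only a sketch, not a drop-in implementation: it assumes your network preserves spatial size, that it will accept whatever tile sizes fall out at the image border, and that the margin covers at least half the field of view.)

```python
# Sketch of overlap-tile inference for a 3D segmentation network.
# Assumes image is (C, D, H, W), the network preserves spatial size, and
# `margin` covers at least half the network's field of view.
import torch
import torch.nn.functional as F

def tiled_predict(net, image, tile=(96, 96, 96), margin=(16, 16, 16), out_channels=2):
    C, D, H, W = image.shape
    # Mirror-pad so border tiles still have full context.
    pad = (margin[2], margin[2], margin[1], margin[1], margin[0], margin[0])
    padded = F.pad(image.unsqueeze(0), pad, mode="reflect")

    output = torch.empty(out_channels, D, H, W)
    for z in range(0, D, tile[0]):
        for y in range(0, H, tile[1]):
            for x in range(0, W, tile[2]):
                dz = min(tile[0], D - z)
                dy = min(tile[1], H - y)
                dx = min(tile[2], W - x)
                # Input tile = output tile plus a margin of context on every side.
                patch = padded[..., z:z + dz + 2 * margin[0],
                                    y:y + dy + 2 * margin[1],
                                    x:x + dx + 2 * margin[2]]
                with torch.no_grad():
                    pred = net(patch.cuda()).cpu()[0]
                # Keep only the central region; the margins are discarded.
                output[:, z:z + dz, y:y + dy, x:x + dx] = pred[
                    :, margin[0]:margin[0] + dz,
                       margin[1]:margin[1] + dy,
                       margin[2]:margin[2] + dx]
    return output
```

A careful implementation would also pad border tiles up to sizes the U-Net accepts (for example, multiples of its overall downsampling factor) and batch the tiles for throughput.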

Best.

K. Frank

I certainly wish I had a mentor to ask these types of questions, but I don’t. I greatly appreciate your reply.

I shall go through the U-Net paper again and look up the tiling mechanics. My work, I think, calls it chunking or patching, but I understand the gist of what you mean. I’m more worried about the difficulty of rebuilding the image for test metrics/predictions, and about edge artifacts. But I guess it’s time for me to dig into that and learn.

I was certainly hoping for something more in the prepackaged realm, but I’ve already been feeling like this would require me to start modifying packages or creating my own.

Hi Goofy!

If you’re worried about artifacts due to the edges of the tiles (patches), there won’t
be any artifacts if your U-Net is properly designed and you implement tiling correctly.
(Running the entire image through the U-Net all at once will give the same result – up
to possible round-off error – as running the image through tile by tile, assuming your
tiling strategy has been implemented properly, so, by definition, there won’t be any
artifacts.)

If you’re worried about artifacts at the edges of the entire image, yes, that could
happen. But such artifacts wouldn’t be specific to tiling or the U-Net architecture.

U-Net has been quite successful and widely used, so I wouldn’t rule out there being
some pre-packaged tiling implementation. It would probably be worth some effort to
see if you could find one; I just don’t know of any off-hand.

Good luck!

K. Frank