CNN with many input channels runs out of memory

I have a CNN that takes inputs with 10 channels (these are not regular images) and performs segmentation, assigning a class to each pixel.

But I am running out of memory on my GPU.

What is the recommended approach here? I think the input itself is very large, since it has many pixels along its length, width, and depth.

Should I split the input into a grid of smaller boxes and train the network on those? My concern is that the CNN then won’t see the full context around a given box, which could hurt performance.
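
For concreteness, here is a rough sketch of what the gridding could look like. The function names, the tile sizes, and especially the `halo` margin are assumptions made for illustration: the halo is one possible way to give each box some surrounding context, not something established in this thread. The code only computes index ranges, so it works with any array or tensor type that accepts slices.

```python
from itertools import product


def axis_tiles(size, tile, halo):
    """Split one axis of length `size` into cores of at most `tile` elements.

    Each entry is (core_start, core_stop, padded_start, padded_stop), where the
    padded range extends the core by up to `halo` elements on each side and is
    clipped to the axis bounds.
    """
    tiles = []
    for start in range(0, size, tile):
        stop = min(start + tile, size)
        tiles.append((start, stop, max(start - halo, 0), min(stop + halo, size)))
    return tiles


def grid_tiles(shape, tile, halo):
    """Yield (core, padded) tuples of slices covering an N-D spatial volume.

    `padded` is what would be fed to the network (box plus context margin);
    `core` is the region whose predictions would be kept.
    """
    per_axis = [axis_tiles(s, t, h) for s, t, h in zip(shape, tile, halo)]
    for combo in product(*per_axis):
        core = tuple(slice(c0, c1) for c0, c1, _, _ in combo)
        padded = tuple(slice(p0, p1) for _, _, p0, p1 in combo)
        yield core, padded
```

With a channels-first input, the spatial slices would be applied after the channel axis, for example `volume[(slice(None),) + padded]`. Whether a given halo size gives the network enough context depends on its receptive field, which I have not worked out.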

Try reducing the batch size.

If reducing the batch size is not an option, you could also have a look at torch.utils.checkpoint to trade compute for memory.
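
A minimal sketch of how checkpointing might be wired into a segmentation model is below. The `SegNet` class, its layer sizes, and the number of classes are hypothetical placeholders, not the model from the question; only `torch.utils.checkpoint.checkpoint` is the actual API being shown. It assumes a reasonably recent PyTorch release that accepts the `use_reentrant` argument.

```python
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint


class SegNet(nn.Module):
    """Hypothetical 3D segmentation net with 10 input channels."""

    def __init__(self, in_channels=10, num_classes=4, width=16):
        super().__init__()
        self.block1 = nn.Sequential(
            nn.Conv3d(in_channels, width, kernel_size=3, padding=1), nn.ReLU()
        )
        self.block2 = nn.Sequential(
            nn.Conv3d(width, width, kernel_size=3, padding=1), nn.ReLU()
        )
        # 1x1x1 convolution producing per-voxel class logits.
        self.head = nn.Conv3d(width, num_classes, kernel_size=1)

    def forward(self, x):
        # Activations inside each checkpointed block are not kept after the
        # forward pass; they are recomputed during backward, which costs extra
        # compute but lowers peak memory.
        x = checkpoint(self.block1, x, use_reentrant=False)
        x = checkpoint(self.block2, x, use_reentrant=False)
        return self.head(x)


model = SegNet()
x = torch.randn(1, 10, 8, 8, 8)  # (batch, channels, depth, height, width)
logits = model(x)
logits.mean().backward()  # placeholder loss, just to exercise backward
```

How much memory this saves depends on where the block boundaries are placed, so it is worth measuring on the real model.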
