Unfolding 5D input in a memory-friendly manner?

I’m currently trying to train a 3D U-Net in which the convolutional layers are replaced with my own layers. In my implementation, I unfold the 5D input (batch size x channels x h x d x w) so that I can contract it with my kernel tensor, which means I call torch.Tensor.unfold three times per layer, like so:

patches = input.unfold(2, kernel_dim, stride).unfold(3, kernel_dim, stride).unfold(4, kernel_dim, stride)
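For reference, here is a minimal sketch of the unfold-then-contract pattern, checked against F.conv3d. It assumes the kernel tensor has the standard (C_out, C_in, k, k, k) layout, a cubic kernel and no padding (the post does not specify these), and the helper names unfold_contract and patch_blowup are hypothetical. Each unfold call returns a view, so the unfold chain itself allocates nothing. The memory goes into the contraction, which typically copies the overlapping patches into a dense buffer that is close to k³ times the size of the input at stride 1.

```python
import torch
import torch.nn.functional as F


def unfold_contract(x, weight, stride=1):
    """Contract unfolded 5D input with a (C_out, C_in, k, k, k) kernel.

    Hypothetical helper; assumes a cubic kernel and no padding.
    """
    k = weight.shape[-1]
    # Three unfold calls give a view of shape (N, C_in, D', H', W', k, k, k).
    patches = x.unfold(2, k, stride).unfold(3, k, stride).unfold(4, k, stride)
    # The contraction typically copies the overlapping patches into a dense buffer.
    return torch.einsum("ncdhwijk,ocijk->nodhw", patches, weight)


def patch_blowup(spatial, k, stride=1):
    """Ratio of materialized patch elements to input elements (per sample and channel)."""
    ratio = 1.0
    for s in spatial:
        n_windows = (s - k) // stride + 1
        ratio *= n_windows * k / s
    return ratio
```

For a 64³ volume with k = 3 and stride 1, patch_blowup((64, 64, 64), 3) is about 24.5, i.e. the dense patch buffer is roughly 24.5 times larger than the input it was built from.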

Unfortunately, this is also extremely taxing on GPU memory.

I’m training on an HPC cluster and have access to multiple GPUs, each with 32 GB of memory. To save memory, I’m trying to use activation checkpointing with cpu_offload from the fairscale library on each of my double-convolution blocks.
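For comparison, per-block checkpointing can also be sketched with PyTorch’s built-in torch.utils.checkpoint, which avoids the fairscale dependency. DoubleConv and Checkpointed are hypothetical names, and nn.Conv3d is only a placeholder for the custom unfold-based layers. One caveat: checkpointing only avoids storing a block’s activations between the forward and backward passes. The patch tensor inside a single layer is still built during the forward pass and again during recomputation, so it may not lower the peak if one contraction alone is what exhausts the 32 GB.

```python
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint


class DoubleConv(nn.Module):
    """Hypothetical U-Net double-convolution block.

    nn.Conv3d is only a placeholder for the custom unfold-based layers.
    """

    def __init__(self, c_in, c_out):
        super().__init__()
        self.body = nn.Sequential(
            nn.Conv3d(c_in, c_out, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.Conv3d(c_out, c_out, kernel_size=3, padding=1),
            nn.ReLU(),
        )

    def forward(self, x):
        return self.body(x)


class Checkpointed(nn.Module):
    """Discards the wrapped block's activations and recomputes them in backward."""

    def __init__(self, block):
        super().__init__()
        self.block = block

    def forward(self, x):
        if torch.is_grad_enabled():
            # Non-reentrant checkpointing, the variant recent PyTorch releases recommend.
            return checkpoint(self.block, x, use_reentrant=False)
        # No backward pass will follow (e.g. inference), so there is nothing to save.
        return self.block(x)
```

For the offloading part, PyTorch also provides the torch.autograd.graph.save_on_cpu context manager, which keeps tensors saved for backward in host memory instead of on the GPU. It is an alternative to fairscale’s offloading, with the same caveat about the per-layer peak.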

I haven’t managed to make it work yet. Is there a memory-friendly alternative to unfolding 5D input this way? I found several posts from people with similar problems (which is how I ended up trying fairscale), but I haven’t found a solution so far.
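One unfold-free alternative, sketched below under the assumption that the kernel has the standard (C_out, C_in, kd, kh, kw) layout (the post does not say), is to loop over the kernel offsets. For each offset (i, j, k), take a strided slice of the input (a view, no copy) and contract only the channel axis with weight[:, :, i, j, k], accumulating into the output. Because autograd could otherwise keep a copy of every per-offset operand for backward, the sketch wraps the loop in a custom torch.autograd.Function that saves only the input and the kernel and repeats the loop in backward. OffsetConv3d, offset_conv3d and _offset_view are hypothetical names, not a library API.

```python
import torch


def _offset_view(x, i, j, k, out_shape, stride):
    """Strided view of x aligned with kernel offset (i, j, k); no copy is made."""
    d_out, h_out, w_out = out_shape
    return x[:, :,
             i : i + stride * (d_out - 1) + 1 : stride,
             j : j + stride * (h_out - 1) + 1 : stride,
             k : k + stride * (w_out - 1) + 1 : stride]


class OffsetConv3d(torch.autograd.Function):
    """Unfold-free 3D contraction that saves only the input and the kernel.

    Forward and backward both loop over the kd*kh*kw kernel offsets, so each
    temporary is at most about input- or output-sized, instead of a patch
    buffer that is roughly kd*kh*kw times larger than the input.
    """

    @staticmethod
    def forward(ctx, x, weight, stride):
        n, _, d, h, w = x.shape
        c_out, _, kd, kh, kw = weight.shape
        out_shape = ((d - kd) // stride + 1,
                     (h - kh) // stride + 1,
                     (w - kw) // stride + 1)
        out = x.new_zeros(n, c_out, *out_shape)
        for i in range(kd):
            for j in range(kh):
                for k in range(kw):
                    xs = _offset_view(x, i, j, k, out_shape, stride)
                    # Contract the input-channel axis with the (C_out, C_in) kernel slice.
                    out += torch.einsum("ncdhw,oc->nodhw", xs, weight[:, :, i, j, k])
        ctx.save_for_backward(x, weight)
        ctx.stride = stride
        ctx.out_shape = out_shape
        return out

    @staticmethod
    def backward(ctx, grad_out):
        x, weight = ctx.saved_tensors
        _, _, kd, kh, kw = weight.shape
        grad_x = torch.zeros_like(x)
        grad_w = torch.zeros_like(weight)
        for i in range(kd):
            for j in range(kh):
                for k in range(kw):
                    xs = _offset_view(x, i, j, k, ctx.out_shape, ctx.stride)
                    # Kernel gradient for this offset: sum over batch and output positions.
                    grad_w[:, :, i, j, k] = torch.einsum("nodhw,ncdhw->oc", grad_out, xs)
                    # Add the input gradient back into the (possibly overlapping) window.
                    _offset_view(grad_x, i, j, k, ctx.out_shape, ctx.stride).add_(
                        torch.einsum("nodhw,oc->ncdhw", grad_out, weight[:, :, i, j, k]))
        # No gradient for the integer stride argument.
        return grad_x, grad_w, None


def offset_conv3d(x, weight, stride=1):
    """Convenience wrapper; stride is passed explicitly to Function.apply."""
    return OffsetConv3d.apply(x, weight, stride)
```

This trades memory for time: kd*kh*kw small contractions are usually slower than one large one. If the custom layer contracts patches with a differently shaped kernel, the two einsum strings would need to change accordingly, and double backward is not supported in this sketch.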