How to manage huge input tensor in PyTorch

Hello everyone,
I am currently working on a model that takes in a large volume with the shape (1, 200, 300, 300), where the first axis represents the number of channels, and the subsequent axes represent the width, height, and depth of the volume, respectively. The objective is to generate a segmentation map of the same shape as the input.

However, I am facing challenges due to the substantial memory requirements of the activation maps. Despite employing down/upscaling techniques, the combined size of the first and last activation maps amounts to approximately 12GB, which exhausts the memory of a single GPU.

While I am aware of patch-based training, I would prefer to explore alternative approaches for several reasons. Model parallelism seems like a promising solution, but it is likely to be highly inefficient: each GPU would run only a few of the network’s convolutions, yet it would still have to process the full-size volume, and the other GPUs would sit idle in the meantime.

Ideally, I would like to split the input tensor across multiple GPUs, allowing for parallel processing, and then aggregate the results into a single tensor. However, I have been unable to find any libraries or frameworks that provide such functionality.

As an example, let’s consider applying a 3x3 convolutional filter with stride 1 and padding 1 to a 1x10x10 image. We can split the image into four overlapping 1x6x6 chunks: each 5x5 quadrant plus a one-pixel halo along its interior edges, so that the kernel sees real neighbour values at the seams. Each chunk is sent to a different GPU, all GPUs apply the same conv filter to their chunk, and afterward the valid 5x5 output regions are stitched back into a single 1x10x10 tensor.
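To make this concrete, here is a minimal single-process sketch of that halo scheme (the chunks stay on the CPU; the commented `.to(f"cuda:{i}")` line marks where each chunk would be shipped to its own GPU). It reproduces the monolithic convolution exactly:

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
x = torch.randn(1, 1, 10, 10)                # (N, C, H, W) toy input
w = torch.randn(1, 1, 3, 3)                  # a single 3x3 filter

# Reference result: convolve the whole image at once.
full = F.conv2d(x, w, padding=1)             # -> (1, 1, 10, 10)

out = torch.empty_like(full)
for r, rows in enumerate((slice(0, 6), slice(4, 10))):
    for c, cols in enumerate((slice(0, 6), slice(4, 10))):
        chunk = x[:, :, rows, cols]          # 6x6 = 5x5 quadrant + 1-px halo
        # chunk = chunk.to(f"cuda:{2 * r + c}")  # ship to its own GPU
        y = F.conv2d(chunk, w, padding=1)    # 6x6 output
        # Keep only the 5x5 quadrant: the rows/cols next to the interior
        # seam saw zero padding instead of real neighbours, so drop them.
        out[:, :, r * 5:r * 5 + 5, c * 5:c * 5 + 5] = y[:, :, r:r + 5, c:c + 5]

print(torch.allclose(out, full))             # True: chunked == monolithic
```

The same halo trick generalizes to 3D volumes and to stacks of convolutions, although the halo width grows with the receptive field of the layers between synchronization points.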

I would greatly appreciate any guidance or suggestions on how to address this problem more effectively. Thank you in advance for your assistance.

Best regards,
Luca

https://www.kernel-operations.io/keops/index.html
You could always use PyTorch KeOps, which allows for differentiable operations on huge tensors with PyTorch (their examples use 2,000,000x3 tensors). I’m not sure whether DDP support is a firm yes; their GitHub is a bit hazy on that.
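For reference, the kind of thing it is good at looks like this (the Gaussian-kernel matrix-vector product from their docs, shrunk down a bit; a LazyTensor builds a symbolic formula, and nothing of size M x N is ever materialized):

```python
import torch
from pykeops.torch import LazyTensor

M, N, D = 100_000, 100_000, 3
x = torch.randn(M, D)
y = torch.randn(N, D)
b = torch.randn(N, 1)

x_i = LazyTensor(x[:, None, :])    # symbolic (M, 1, D)
y_j = LazyTensor(y[None, :, :])    # symbolic (1, N, D)
D_ij = ((x_i - y_j) ** 2).sum(-1)  # symbolic (M, N) squared distances
K_ij = (-D_ij).exp()               # symbolic Gaussian kernel matrix
out = K_ij @ b                     # reduction computed here -> (M, 1) tensor
```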

Hi Andrew, thanks for pointing me to such a cool library!
I tried to dig into it, but I’m not sure how to use it to perform a convolution (or a transposed convolution) with PyTorch, and I couldn’t find anyone using the library for classical operations such as Linear or Convolution. It doesn’t seem to fit my need: as soon as I try to instantiate a 3D tensor as a LazyTensor, it complains that the shape should be one of (…, M, 1, D), (…, 1, N, D), or (…, 1, 1, D), while mine would be (B, C, M, N, D), since it is a 3D volume.

I’m not familiar with your particular use case. I have used pykeops in the past to implement 3D point cloud convs.

Some generic recommendations:

  1. You can try depthwise convs for 3D, in the style of MobileNet-like networks
  2. You can try grouped convs for 3D, in the style of ShuffleNet-like networks
    – For 1 and 2 you can use https://github.com/okankop/Efficient-3DCNNs, which I’ve used in the past; it can sharply cut down on memory usage (see the grouped-conv sketch after this list)
  3. You can try activation checkpointing (aka gradient checkpointing); see the PyTorch torch.utils.checkpoint.checkpoint function and the sketch below
  4. As an alternative to 3, DeepSpeed provides the deepspeed.checkpointing.checkpoint function, which has a checkpoint_in_cpu option that, when turned on, offloads the activations to CPU memory rather than recomputing them. I’ve never used the DeepSpeed implementation myself, but a rough sketch from its docs is below
  5. If you are running out of memory in the backward pass, you can try FSDP (FullyShardedDataParallel), which is built into PyTorch; see the last sketch below
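For 1 and 2, a quick sketch of the parameter savings with plain torch.nn.Conv3d (the Efficient-3DCNNs repo wraps this pattern into full architectures):

```python
import torch.nn as nn

C = 32  # channels

# Standard 3D conv: every output channel mixes all input channels.
standard = nn.Conv3d(C, C, kernel_size=3, padding=1)

# Depthwise 3D conv (groups=C): each channel is filtered independently,
# usually followed by a 1x1x1 pointwise conv to mix channels again.
depthwise = nn.Conv3d(C, C, kernel_size=3, padding=1, groups=C)
pointwise = nn.Conv3d(C, C, kernel_size=1)

n_std = sum(p.numel() for p in standard.parameters())
n_dw = sum(p.numel() for p in depthwise.parameters()) + \
       sum(p.numel() for p in pointwise.parameters())
print(n_std, n_dw)  # 27680 vs 1952: roughly 14x fewer parameters
```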
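For 3, a minimal checkpointing example: the activations inside the wrapped block are not stored during the forward pass and are recomputed during backward, trading compute for memory (use_reentrant=False is the recommended variant on recent PyTorch versions):

```python
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint

block = nn.Sequential(
    nn.Conv3d(1, 16, kernel_size=3, padding=1),
    nn.ReLU(),
    nn.Conv3d(16, 1, kernel_size=3, padding=1),
)
x = torch.randn(1, 1, 20, 30, 30, requires_grad=True)

# Forward without storing intermediate activations; they are
# recomputed from x during the backward pass.
y = checkpoint(block, x, use_reentrant=False)
y.sum().backward()
```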
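For 4, again with the caveat that I’ve never run it, the DeepSpeed docs suggest usage roughly like this (configure is called once up front; per the docs, checkpoint_in_cpu only works together with partition_activations):

```python
# Untested sketch based on the DeepSpeed activation-checkpointing docs.
import deepspeed
import torch
import torch.nn as nn

deepspeed.checkpointing.configure(None, partition_activations=True,
                                  checkpoint_in_cpu=True)

block = nn.Sequential(
    nn.Conv3d(1, 16, kernel_size=3, padding=1),
    nn.ReLU(),
    nn.Conv3d(16, 1, kernel_size=3, padding=1),
).cuda()
x = torch.randn(1, 1, 20, 30, 30, requires_grad=True, device="cuda")

# Drop-in replacement for torch.utils.checkpoint.checkpoint that
# offloads the saved activations to CPU memory instead of recomputing.
y = deepspeed.checkpointing.checkpoint(block, x)
```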
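And for 5, a minimal FSDP sketch. Note that FSDP shards parameters, gradients, and optimizer state across ranks, but it does not shard activations, so you may want to combine it with checkpointing if the activation maps are the bottleneck:

```python
import os
import torch
import torch.distributed as dist
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

# Launch with: torchrun --nproc_per_node=<num_gpus> this_script.py
dist.init_process_group("nccl")
torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))

model = torch.nn.Sequential(
    torch.nn.Conv3d(1, 16, kernel_size=3, padding=1),
    torch.nn.ReLU(),
    torch.nn.Conv3d(16, 1, kernel_size=3, padding=1),
).cuda()

# Each rank holds only a shard of the parameters; full parameters are
# gathered on the fly for each forward/backward computation.
model = FSDP(model)
```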