Seeking advice on computing the backprop step in chunks

Hi,
I am facing a challenge with autograd when a computation temporarily involves a very large tensor that won’t fit on my GPU. If there were no need to retain gradients, a chunk-by-chunk calculation would work. However, autograd requires that the large intermediate tensor be kept on the GPU for the backward pass.
Consider a pipeline: small tensor a → very large tensor b → small tensor c. It seems feasible that if the forward pass can be computed chunk-by-chunk, the gradient of c with respect to a could be as well.

  • Concrete Scenario:
    Two tensors x, y of shapes (256, 256, 128) and (256, 1024, 128) combine to form a large tensor L of shape (256, 256 * 1024, 2 * 128) by concatenating channels for each vector pairing from x and y. A series of linear layers then reduces this to a smaller tensor z of shape (256, 256 * 1024, 1). A sketch of this direct computation is below.
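For reference, here is a minimal sketch of the direct (un-chunked) computation I am trying to avoid, with `pair_mlp` as a placeholder for my stack of linear layers (the layer widths here are made up):

```python
import torch
import torch.nn as nn

B, Nx, Ny, C = 256, 256, 1024, 128
x = torch.randn(B, Nx, C, requires_grad=True)
y = torch.randn(B, Ny, C, requires_grad=True)

# Placeholder for the series of linear layers: (..., 2 * C) -> (..., 1)
pair_mlp = nn.Sequential(nn.Linear(2 * C, 64), nn.ReLU(), nn.Linear(64, 1))

# Direct version: materializes L of shape (256, 256 * 1024, 256),
# roughly 17e9 elements, about 64 GiB in fp32, which is why it won't fit on my GPU.
x_exp = x.unsqueeze(2).expand(B, Nx, Ny, C)   # (B, Nx, Ny, C)
y_exp = y.unsqueeze(1).expand(B, Nx, Ny, C)   # (B, Nx, Ny, C)
L = torch.cat([x_exp, y_exp], dim=-1).reshape(B, Nx * Ny, 2 * C)
z = pair_mlp(L)                               # (B, Nx * Ny, 1)
```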

I would like to compute the gradients of z with respect to x and y. The way it works in my mind is: x and y are stored; the forward pass computes z in chunks, so it never has to materialize L completely; and the backward pass computes the gradients in chunks (possible because the linear layers are well behaved, so each chunk of L can be recomputed from x and y and backpropagated independently), likewise never materializing L completely.
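Forward-only, the chunked computation is easy to write (chunking over the rows of x, and reusing x, y, and pair_mlp from the sketch above). The catch is that with autograd enabled, every chunk of L is still saved for the backward pass, so peak memory does not actually drop:

```python
chunk = 16                          # rows of x per step; an arbitrary choice
outs = []
for i in range(0, Nx, chunk):
    xb = x[:, i:i + chunk]                            # (B, n, C) with n <= chunk
    n = xb.shape[1]
    xe = xb.unsqueeze(2).expand(B, n, Ny, C)
    ye = y.unsqueeze(1).expand(B, n, Ny, C)
    L_blk = torch.cat([xe, ye], dim=-1).reshape(B, n * Ny, 2 * C)
    outs.append(pair_mlp(L_blk))                      # (B, n * Ny, 1)
z = torch.cat(outs, dim=1)                            # (B, Nx * Ny, 1)

# Under torch.no_grad() this peaks at a single L_blk, but with grad enabled
# autograd keeps every L_blk (and the hidden activations) alive until backward,
# which adds up to the full L again.
```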

  • Potential Solutions Explored:
    • A custom autograd Function like the example in the docs. This almost suits my needs, but I would have to manually pass in the weight and bias tensors for the linear layers (a lot of tensors if I have several linear layers), and I would have to reconstruct the functionality of nn.Linear myself.
    • Gradient checkpointing, although the backward pass still materializes the whole L tensor (see the sketch after this list).
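For the second bullet, what I mean by gradient checkpointing is roughly the following, reusing the pairwise construction from the first sketch. Forward no longer saves the intermediate activations, but the recomputation during backward rebuilds the whole L in one piece:

```python
from torch.utils.checkpoint import checkpoint

def pairwise_block(x, y):
    # Same construction as the first sketch: build L, then apply the linear layers.
    xe = x.unsqueeze(2).expand(B, Nx, Ny, C)
    ye = y.unsqueeze(1).expand(B, Nx, Ny, C)
    L = torch.cat([xe, ye], dim=-1).reshape(B, Nx * Ny, 2 * C)
    return pair_mlp(L)

# Backward re-runs pairwise_block in full, so the complete L shows up again there.
z = checkpoint(pairwise_block, x, y, use_reentrant=False)
```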

Any suggestions on handling this scenario would be appreciated!

Figured it out. I can pass an nn.Module directly into my custom autograd Function, and that turned out to be an acceptable design. A sketch is below.
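For anyone who finds this later, here is a rough sketch of the kind of Function I ended up with. The names, chunk size, and the build_chunk helper are just illustrative, and pair_mlp is the placeholder module from the first sketch; the idea is that forward runs chunk-by-chunk under no_grad, and backward recomputes one chunk at a time with grad enabled and accumulates the gradients:

```python
import torch

def build_chunk(xb, y):
    # Concatenate channels for every pairing of the rows of xb with the rows of y.
    B, n, C = xb.shape
    Ny = y.shape[1]
    xe = xb.unsqueeze(2).expand(B, n, Ny, C)
    ye = y.unsqueeze(1).expand(B, n, Ny, C)
    return torch.cat([xe, ye], dim=-1).reshape(B, n * Ny, 2 * C)

class ChunkedPairwise(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x, y, module, chunk):
        ctx.save_for_backward(x, y)
        ctx.module = module       # the nn.Module is passed straight into the Function
        ctx.chunk = chunk
        outs = []
        with torch.no_grad():     # keep autograd from holding on to each chunk of L
            for i in range(0, x.shape[1], chunk):
                outs.append(module(build_chunk(x[:, i:i + chunk], y)))
        return torch.cat(outs, dim=1)

    @staticmethod
    def backward(ctx, grad_z):
        x, y = ctx.saved_tensors
        module, chunk, Ny = ctx.module, ctx.chunk, y.shape[1]
        grad_x, grad_y = torch.zeros_like(x), torch.zeros_like(y)
        for i in range(0, x.shape[1], chunk):
            xb = x[:, i:i + chunk].detach().requires_grad_(True)
            yb = y.detach().requires_grad_(True)
            with torch.enable_grad():          # recompute only this chunk, with grad
                zb = module(build_chunk(xb, yb))
            gz = grad_z[:, i * Ny:(i + xb.shape[1]) * Ny]
            torch.autograd.backward(zb, gz)    # also accumulates .grad on module params
            grad_x[:, i:i + xb.shape[1]] = xb.grad
            grad_y += yb.grad
        return grad_x, grad_y, None, None

# Usage: z = ChunkedPairwise.apply(x, y, pair_mlp, 16); z.sum().backward()
# fills in x.grad, y.grad, and the parameter gradients of pair_mlp.
```

Peak memory in both directions is then one chunk of L plus x, y, and the module parameters, at the cost of recomputing each chunk once during the backward pass.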