Offload models to CPU using autograd.Function

I was wondering if it was possible to do something like the following, where I try to load the model from CPU -> GPU before the computation and send it back after:

import itertools

import torch
from torch import nn

DUMMY = torch.empty((), requires_grad=True)

class Clive(torch.autograd.Function):
    @staticmethod
    def forward(ctx, layer, dummy, inputs):
        layer.to(inputs.device)  # load the weights CPU -> GPU
        with torch.enable_grad():
            outputs = layer(inputs)
        ctx.layer = layer
        ctx.inputs = inputs
        ctx.outputs = outputs
        layer.to('cpu')  # send the weights back GPU -> CPU
        return outputs.detach()

    @staticmethod
    def backward(ctx, *grad_outputs):
        ctx.layer.to(grad_outputs[0].device)  # reload the weights for the backward pass
        diff_with = ctx.layer.parameters()
        if ctx.inputs.requires_grad:
            diff_with = itertools.chain([ctx.inputs], diff_with)

        grads = torch.autograd.grad(outputs=(ctx.outputs,), inputs=diff_with, grad_outputs=grad_outputs)
        if ctx.inputs.requires_grad:
            input_grad, *grads = grads
        else:
            input_grad = None

        # The layer is not a tensor input of the Function, so accumulate
        # its parameter gradients by hand.
        for g, p in zip(grads, ctx.layer.parameters()):
            p.grad = g if p.grad is None else p.grad + g
        ctx.layer.to('cpu')
        return (None, None, input_grad)

m1 = nn.Linear(10, 10)
x = torch.randn(10, 10, device=0)
Clive.apply(m1, DUMMY, x).sum().backward()

Unfortunately, this fails with "One of the differentiated Tensors appears to not have been used in the graph. Set allow_unused=True if this is the desired behavior." I’m still working to understand this error; can anyone dumb it down for me? If I get rid of the .to(...) method calls, it works OK.

(I don’t really know if this is practical/useful, I’m just curious if I can actually do it :D)


Moving tensors across devices is differentiable, so you don’t need a custom Function if you just want to use different devices.
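For example, gradients flow back through Tensor.to() calls in both directions. A minimal sketch; the CPU fallback for `device` is an assumption added so it also runs without a GPU (the moves then become no-ops):

```python
import torch

# Use the GPU when available; otherwise fall back to the CPU so the sketch runs.
device = "cuda" if torch.cuda.is_available() else "cpu"

x = torch.randn(4, 4, requires_grad=True)  # leaf tensor on the CPU
y = x.to(device) * 2                       # Tensor.to() is recorded by autograd
out = y.to("cpu").sum()                    # ...and so is the move back
out.backward()

# The gradient flows back through both moves and lands on the CPU leaf.
print(x.grad.device)                                    # cpu
print(torch.allclose(x.grad, torch.full((4, 4), 2.0)))  # True
```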


I was also interested in offloading the computed outputs to the CPU (like gradient checkpointing, but without recomputing), but you’re right, I can just wrap the model.

Speaking of which, is there something obvious I am doing wrong here? It seems to fail if I try to run another backward pass when gradients are already allocated.

import torch
from torch import nn
class Wrapped(nn.Module):
    def __init__(self):
        super().__init__()
        self.l = nn.Linear(10, 10)

    def forward(self, x):
        # Tensor.to() is differentiable, so the output can be moved to the CPU
        ans = self.l(x).to('cpu')
        return ans

inp = torch.randn(10, 10, device=0)
w = Wrapped()

# Standard forward/backward

# Clear grads and do it again
for p in w.parameters():
    p.grad = None

# Zero grads like we should do, fails with "Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cpu!"

Actually, it doesn’t seem to work for my use case: I have to reload the model onto the GPU before running the backward pass. I’ve mimicked a trivial training loop here:

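import torch
from torch import nn

m = nn.Linear(10, 10)
inputs = [
    torch.randn(10, 10, device=0),
    torch.randn(10, 10, device=0),
]
for input in inputs:
    m.to(0)  # load the weights onto the GPU for the forward pass
    activation = m(input).to('cpu')  # offload the output to the CPU
    m.to('cpu')  # send the weights back to the CPU
    loss = activation.sum()
    m.to(0)  # <--- without this copy, it doesn't work
    loss.backward()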

Obviously, in this example this is fine, but what I’m aiming for is to load the layer’s weights into GPU memory only when they are needed (e.g. for the forward and backward passes). That was what I was going for with the autograd Function.


You shouldn’t use the nn.Module .to() functions inside your forward; they won’t work well with autograd.
Only the Tensor .to() that you were mentioning above will.

The reason for this is that Module Parameters are always leaf Tensors, and so they cannot be updated in a differentiable manner.
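To make the leaf-Tensor point concrete, here is a small CPU-only sketch that, as far as I can tell, reproduces the original "not used in the graph" error. It assumes a dtype change (`lin.to(torch.float64)`) behaves like a device move here, since both rewrite the Parameters in place through Module.to() rather than through a recorded autograd op:

```python
import torch
from torch import nn

lin = nn.Linear(3, 3)
out = lin(torch.randn(2, 3)).sum()  # graph built against lin's current Parameters

print(lin.weight.is_leaf, lin.weight.grad_fn)  # True None

# Module.to() swaps each Parameter's data in place; autograd records nothing.
# A dtype change stands in for a device move (assumption) so this runs on CPU.
lin.to(torch.float64)
print(lin.weight.is_leaf, lin.weight.grad_fn)  # True None

message = ""
try:
    # The graph behind `out` no longer connects to the rewritten Parameters.
    torch.autograd.grad(out, list(lin.parameters()))
except RuntimeError as err:
    message = str(err)
print(message)
```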

Note that, more generally, I don’t think your approach will help with GPU memory, as autograd needs to keep some values around to be able to compute the backward pass, and your model’s weights are among them. So even if you move the model to the CPU, these saved buffers will stay on the GPU.
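One way to see which tensors autograd keeps is a saved-tensor hook that just records what it is given. This sketch assumes a PyTorch version that provides `torch.autograd.graph.saved_tensors_hooks` (newer than this thread). It runs on the CPU, but the same tensors would be saved on the GPU; it checks whether any saved tensor shares memory with the layer’s weight:

```python
import torch
from torch import nn

lin = nn.Linear(10, 10)
x = torch.randn(4, 10, requires_grad=True)

saved = []

def pack(tensor):
    # Called for every tensor autograd saves for the backward pass.
    saved.append(tensor)
    return tensor

def unpack(tensor):
    return tensor

with torch.autograd.graph.saved_tensors_hooks(pack, unpack):
    out = lin(x).sum()

# A view of the weight (e.g. its transpose) shares the weight's data pointer.
weight_is_saved = any(t.data_ptr() == lin.weight.data_ptr() for t in saved)
print(weight_is_saved)
```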

Hi, @albanD

Thanks for the explanation! I’m curious whether there is any simple way to offload the weights / intermediate outputs to CPU memory and bring them back to GPU memory when they are needed (e.g. during backward, or wherever some op needs them). By this I mean true offloading that actually saves GPU memory.

I’m trying to do it with an autograd Function, e.g. by updating the tensor data in place, but it still doesn’t work (I get a backward error about the gradient format).

I’m afraid there is no simple way to do this today.
We are working on adding this, though: "SavedVariable default hooks", Issue #58659 on pytorch/pytorch (GitHub). This should be done in 1-2 months.
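For later readers: as far as I know, that issue shipped (around PyTorch 1.10) as saved tensor hooks, together with a ready-made `torch.autograd.graph.save_on_cpu` context manager that does this offloading. A minimal sketch, assuming a PyTorch version that has it; the small model is only for illustration:

```python
import torch
from torch import nn

device = "cuda" if torch.cuda.is_available() else "cpu"
torch.manual_seed(0)
model = nn.Sequential(nn.Linear(10, 10), nn.ReLU(), nn.Linear(10, 1)).to(device)
x = torch.randn(8, 10, device=device)

# Tensors autograd saves during this forward pass are packed into CPU memory and
# copied back to their original device when backward needs them.
# pin_memory=True uses pinned memory, making the copies back asynchronous on CUDA.
with torch.autograd.graph.save_on_cpu(pin_memory=torch.cuda.is_available()):
    loss = model(x).sum()
loss.backward()
offloaded_grads = [p.grad.clone() for p in model.parameters()]

# Reference run without offloading, for comparison.
model.zero_grad()
model(x).sum().backward()
same = all(
    torch.allclose(a, p.grad) for a, p in zip(offloaded_grads, model.parameters())
)
print(same)  # True
```

The Parameters themselves stay on the GPU as usual; the memory saved comes from the activations autograd keeps for backward.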

The workaround of using a custom Function would work, but you will need to save foo.cpu() on the ctx instead of using save_for_backward, and then, in the backward formula, move it back to the GPU before using it.
But that means that you have to rewrite most of the backward formulas, which is very painful.
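A minimal sketch of that workaround, for a hypothetical `OffloadedLinear` Function computing `y = x @ w.T`. The activation `x` is kept on the CPU as a plain ctx attribute and moved back in backward, while the weight, which stays on the GPU anyway, goes through save_for_backward. Even this single op needs its backward formula written by hand:

```python
import torch

class OffloadedLinear(torch.autograd.Function):
    """Hypothetical example: y = x @ w.T with the activation offloaded to the CPU."""

    @staticmethod
    def forward(ctx, x, w):
        ctx.x_cpu = x.detach().cpu()  # activation: kept in CPU memory
        ctx.save_for_backward(w)      # weight: lives on the GPU anyway
        return x @ w.t()

    @staticmethod
    def backward(ctx, grad_out):
        (w,) = ctx.saved_tensors
        x = ctx.x_cpu.to(grad_out.device)  # bring the activation back
        grad_x = grad_out @ w              # dL/dx = dL/dy @ w
        grad_w = grad_out.t() @ x          # dL/dw = (dL/dy)^T @ x
        return grad_x, grad_w

# Usage (on the CPU here; pass CUDA tensors to actually offload).
x = torch.randn(4, 3, requires_grad=True)
w = torch.randn(5, 3, requires_grad=True)
OffloadedLinear.apply(x, w).sum().backward()
```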

Great, thanks! Looking forward to the new functionality!