Saving the autograd state of a model for later backprop

I’d like to save all the state generated during the forward pass so that I can run backprop at any later time. The reason is to be able to train large, composite models on a home computer with a single GPU.

Here’s the general scheme I have in mind:

  1. Load one model, compute forward(), and collect the outputs.
  2. Load another model and compute forward() on the outputs of the first.
  3. Compute backward on the second model and get the gradient with respect to its inputs (updating its parameters along the way).
  4. Re-load the first model and compute its backward pass using the gradients obtained in (3).
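The four steps above can be sketched in plain PyTorch by cutting the graph at the boundary between the two models with `detach()` and then seeding the first model's backward pass with the gradient from the second. This is a minimal sketch under assumptions: `model1`/`model2` are small hypothetical stand-ins, the MSE loss is arbitrary, and nothing is swapped out of memory, so `out1` keeps the first model's saved tensors alive between steps 1 and 4.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Hypothetical stand-ins for the two halves of the composite model.
model1 = nn.Sequential(nn.Linear(16, 32), nn.ReLU())
model2 = nn.Linear(32, 1)
opt1 = torch.optim.SGD(model1.parameters(), lr=0.1)
opt2 = torch.optim.SGD(model2.parameters(), lr=0.1)

x = torch.randn(8, 16)
target = torch.randn(8, 1)

# Step 1: forward through model1; out1 still references model1's autograd graph.
out1 = model1(x)

# Step 2: start a fresh graph for model2 from a leaf copy of out1.
inp2 = out1.detach().requires_grad_()
loss = nn.functional.mse_loss(model2(inp2), target)

# Step 3: backprop through model2 only, update it, and keep dLoss/d(out1).
opt2.zero_grad()
loss.backward()
opt2.step()
grad_at_boundary = inp2.grad

# Step 4: resume backprop through model1, seeded with the boundary gradient.
opt1.zero_grad()
out1.backward(grad_at_boundary)
opt1.step()
```

Because `inp2` is a leaf, `loss.backward()` stops at the boundary and frees only model2's graph. By the chain rule, model1 then receives the same gradients an end-to-end backward would have produced.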

Ideally, I’d keep the parameters for both models in GPU memory and just swap the intermediate and optimizer states for the models in and out of memory. I realize this might be quite slow, but I’d really like to experiment with it.
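For the optimizer-state half of that idea, one option is to move the tensors held in `optimizer.state` between devices while leaving the parameters where they are. A minimal sketch; `move_optimizer_state` is a hypothetical helper, and it deliberately skips the scalar `step` counter, which costs almost no memory and which some optimizers intentionally keep on the CPU.

```python
import torch

def move_optimizer_state(optimizer, device):
    """Hypothetical helper: move per-parameter optimizer state (e.g. Adam moments) to `device`."""
    for state in optimizer.state.values():
        for key, value in state.items():
            # 'step' is a scalar counter; leave it where the optimizer put it.
            if key != "step" and torch.is_tensor(value):
                state[key] = value.to(device)

device = "cuda" if torch.cuda.is_available() else "cpu"
model = torch.nn.Linear(4, 2).to(device)
opt = torch.optim.Adam(model.parameters())

model(torch.randn(3, 4, device=device)).sum().backward()
opt.step()                         # creates exp_avg / exp_avg_sq state

move_optimizer_state(opt, "cpu")   # swap out while the other model trains
move_optimizer_state(opt, device)  # swap back in before the next opt.step()
```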

I’ve done a bit of tinkering trying to get this to work already. It seems like the intermediate states I’m after are stored in the tensors returned by each operation, e.g.:

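import torch

# Shapes are arbitrary; requires_grad=True makes autograd record the graph.
x, y = torch.randn(5, requires_grad=True), torch.randn(5, requires_grad=True)
t1 = x + y
t2 = t1 * x
output = t2.softmax(dim=0)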

In this example, the ‘t’ variables seem to hold that state in their ‘data’ property. I found this by building two identical graphs with different inputs, running the forward pass on each, and then copying the data property of every intermediate tensor from one graph to the other. Backprop then computes the same gradients for both graphs.
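In recent PyTorch releases (1.10 and later, if I remember correctly), the tensors autograd saves for backward are exposed on the `grad_fn` nodes as `_saved_*` attributes, and `next_functions` links each node to the nodes of its inputs. The sketch below rebuilds the small example and crawls its graph backward, collecting every saved tensor. `saved_tensors_in_graph` is a hypothetical helper, and the `_saved_*` names are underscore-prefixed (semi-private), so treat them as subject to change.

```python
import torch

def saved_tensors_in_graph(root):
    """Hypothetical helper: walk grad_fn nodes backward from `root`,
    returning (node class name, attribute name, tensor) for each saved tensor."""
    found = []
    visited = {}                  # id -> node; holding the node keeps its id unique
    stack = [root.grad_fn]
    while stack:
        node = stack.pop()
        if node is None or id(node) in visited:
            continue
        visited[id(node)] = node
        for attr in dir(node):
            if attr.startswith("_saved_"):
                value = getattr(node, attr)
                if torch.is_tensor(value):   # skip saved ints, dims, scalars
                    found.append((type(node).__name__, attr, value))
        # next_functions holds (node, input_nr) pairs; inputs without grad give None.
        stack.extend(fn for fn, _ in node.next_functions)
    return found

x = torch.randn(5, requires_grad=True)
y = torch.randn(5, requires_grad=True)
t1 = x + y
t2 = t1 * x
output = t2.softmax(dim=0)

for node_name, attr, tensor in saved_tensors_in_graph(output):
    print(node_name, attr, tuple(tensor.shape))
```

As far as I know these attributes are read-only, so this is useful for inspecting the saved state rather than for writing it back.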

The thing is, this isn’t a very robust approach. First of all, it is pretty much impossible to get these intermediate tensors for some operations. It’s also a real pain to record them. Finally, you cannot do this with the nn package because it’ll throw an error when you try to do backprop on the doctored model.

I think the solution to this might involve tinkering with the C++ code. It seems like the grad_fn objects built up by autograd must contain references to this state data. If I could access this state data from these grad_fn objects, I could just crawl the graph backward and copy the state data to and from them. Can anyone point me to where in the C++ code these states are stored? Tips on alternative approaches are welcome too.
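One alternative that needs no C++ changes in recent PyTorch releases is `torch.autograd.graph.saved_tensors_hooks`: autograd calls a pack hook whenever it saves a tensor during the forward pass and an unpack hook when backward needs that tensor again, so saved state can be moved off the GPU and brought back lazily. A minimal sketch that offloads saved tensors to CPU memory; the hook names are hypothetical, and PyTorch also provides `torch.autograd.graph.save_on_cpu` as a ready-made version of the same idea.

```python
import torch

def pack_to_cpu(tensor):
    # Called during forward each time autograd saves a tensor for backward.
    return tensor.device, tensor.to("cpu")

def unpack_from_cpu(packed):
    # Called during backward when the saved tensor is needed again.
    device, tensor = packed
    return tensor.to(device)

device = "cuda" if torch.cuda.is_available() else "cpu"
x = torch.randn(5, device=device, requires_grad=True)

with torch.autograd.graph.saved_tensors_hooks(pack_to_cpu, unpack_from_cpu):
    out = (x.pow(2) * x.exp()).sum()   # tensors saved here are kept on the CPU

# Backward can run later; saved tensors are moved back on demand.
out.backward()
```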


I would try to either:

  • reduce the batch size and try to train both models on the device
  • if you have two GPUs, use each GPU for a single model
  • try to use torch.utils.checkpoint to trade compute for memory

Hi ptrblock,
I appreciate the tips. I think these would work for what I’m currently looking at (well, minus the two GPUs… 🙂).

I would still like to look into this deferred-backprop approach, though, if for no other reason than that I’m curious what kind of performance penalty it would entail. If anyone has tips on where the backprop graph’s states are stored, I’d really appreciate it.

For anyone else who finds this question, I managed to get something like what I wanted by modding the code on my own. You’ll need to build your own libpytorch if you want to use it (for now):

I’ve filed a feature request in case the mods are interested in bringing something like this into mainline:
