Computing gradient of output with respect to input -- RuntimeError: Trying to backward through the graph a second time during training

I am building a model that outputs a scalar, and I need to compute the gradient of those outputs with respect to their inputs. I am doing something like the following:



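import torch


def get_gradients(outputs, coordinates):
    """Return minus the gradient of each output with respect to its coordinates."""
    forces = []
    for index, (image_hash, image) in enumerate(coordinates.items()):
        for i, coordinate in image:
            # create_graph=True keeps the result differentiable so the loss can use it
            forces_ = -torch.autograd.grad(
                outputs[index],
                coordinate,
                create_graph=True,
                retain_graph=True,
            )[0]
            forces.append(forces_)
    return torch.stack(forces)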


while not converged:
    optimizer.zero_grad()  # clear previous gradients
    outputs = model(inputs)    # Note that inputs have requires_grad = True
    outputs = torch.stack(list(outputs.values())).sum(0)
    

    all_grad = get_gradients(outputs, coordinates[0])

    outputs = {"scalar": outputs, "gradient": all_grad}

    criterion = CustomLossFunction(gradient=True)
    loss = criterion(outputs, targets)    # This loss uses both scalar and gradient
    
    loss.backward()  
    optimizer.step()

    epoch += 1

    
    if epoch == stop:
        converged = True

The first epoch passes correctly, but at the second epoch I get:

RuntimeError: Trying to backward through the graph a second time, but 
the saved intermediate results have already been freed. Specify 
retain_graph=True when calling backward the first time.

How can this be solved? How can I compute the gradient of the output of my neural network with respect to its inputs without having this issue? Meanwhile, I am reading on the forum about how other people have solved this.

Hi,

This is most likely not caused by your get_gradients() function.
More likely, something you are using in the loop (inputs or coordinates) already has a history (you performed operations on it while it required gradients), so that part of the graph is shared across iterations.
You want to make sure that this is not the case and that you re-create the graph at each iteration.

In particular, you are asking for gradients with respect to coordinates from outputs, while only inputs were passed to the model. I guess you have a graph that links coordinates to inputs, and that graph is shared across iterations.
You should move that computation inside the while loop.
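
As a minimal sketch of the difference (featurize, the Linear model, and the toy loss below are hypothetical stand-ins, not your actual code): with recompute_features=False the features graph is built once, the first loss.backward() frees its saved tensors, and the second iteration fails with the same RuntimeError. With recompute_features=True the graph is rebuilt on every iteration and the loop runs.

```python
import torch


def featurize(coords):
    # Hypothetical stand-in for the coordinates -> features mapping.
    return torch.sin(coords) * coords


def train(recompute_features, steps=2):
    torch.manual_seed(0)
    coords = torch.randn(4, 3, requires_grad=True)
    model = torch.nn.Linear(3, 1)
    optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
    features = featurize(coords)  # graph built once, outside the loop
    for _ in range(steps):
        optimizer.zero_grad()
        if recompute_features:
            features = featurize(coords)  # fresh graph every iteration
        energy = model(features).sum()
        # Gradient of the scalar output with respect to the coordinates,
        # kept differentiable so the loss can depend on it.
        (grad,) = torch.autograd.grad(energy, coords, create_graph=True)
        loss = energy ** 2 + grad.pow(2).sum()  # toy loss on scalar and gradient
        loss.backward()  # frees the saved tensors of every graph it traverses
        optimizer.step()
    return loss.item()
```

Calling train(recompute_features=True) completes, while train(recompute_features=False) raises on the second step because it tries to go back through the already-freed featurize graph.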

This makes total sense and it is the case. This is what I am doing:

Coordinates === mapping function ===> Features  ====> model ====> output

The training loop effectively starts from Features instead of Coordinates (these have requires_grad=True). The features are not being optimized, so to avoid recomputing them at every epoch I calculated them just once. Following your explanation, my mapping function has to be moved into the model or the training loop so that the inputs really are Coordinates and not Features. Is there any way I can make this work without mapping coordinates into features at every epoch?

Hi,

There is no magic trick to force the autograd to only keep part of the graph I’m afraid.
If you want to avoid recomputing this, you can call .backward(retain_graph=True), which might increase your peak memory a little bit. If you do that, I would also recommend that you either:

  • Move the content of the while loop into a separate function, to make sure that all the temporary variables are deleted at the end of each iteration before the next one starts.
  • Add a bunch of del foo at the end of your loop for outputs, all_grad, and loss (again, to make sure they don’t stay alive until the next iteration).
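
A rough sketch of the first option, under the assumption that the features are computed once and reused (featurize, train_step, and the toy loss are made-up names for illustration, not part of your code):

```python
import torch


def featurize(coords):
    # Hypothetical stand-in for the coordinates -> features mapping.
    return torch.sin(coords) * coords


def train_step(model, optimizer, features, coords):
    optimizer.zero_grad()
    energy = model(features).sum()
    (grad,) = torch.autograd.grad(energy, coords, create_graph=True)
    loss = energy ** 2 + grad.pow(2).sum()  # toy loss on scalar and gradient
    loss.backward(retain_graph=True)  # do not free the shared features graph
    optimizer.step()
    return loss.item()  # return a float so no tensor outlives the call


torch.manual_seed(0)
coords = torch.randn(4, 3, requires_grad=True)
features = featurize(coords)  # computed once and reused
model = torch.nn.Linear(3, 1)
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
losses = [train_step(model, optimizer, features, coords) for _ in range(3)]
```

Because energy, grad, and loss are local to train_step, the per-iteration graph can be released when the function returns, and only the shared features graph stays alive. Note that coords.grad keeps accumulating across calls here, since optimizer.zero_grad() only clears the model parameters.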

Thanks @albanD. Your replies were very helpful. After moving the mapping function into the training loop, .backward() no longer complains. The next problem I am trying to solve is that the gradients are basically zero where they should not be :) I hope to figure it out.

You can call .register_hook(fn) on any Tensor, and fn will be called with the grad computed for that Tensor. That can help you print out gradients to track the problem down :)
Good luck!
