Cryptic "No grad accumulator for a saved leaf!" error

I’m currently writing a library on top of torchdiffeq that optimizes spacecraft trajectories with autograd. I’ve written a function that computes the time derivative of the spacecraft state, and I pass that to the odeint_adjoint() function from torchdiffeq, which numerically integrates the ODE and uses the adjoint method to find the gradient of the loss wrt the initial conditions during the backward pass.

The problem I’m running into is that part of my dynamics function needs to use torch.autograd.grad to compute a gradient (specifically, the gradient of the Hamiltonian wrt the spacecraft state); that is, the forward pass of my model itself involves autograd. The odeint_adjoint() implementation calls my function in a torch.no_grad() context, so I’m using torch.enable_grad() in my forward pass to compute the gradient. But no matter what I try, autograd seems to think that the state tensor is not used to compute the Hamiltonian, even though it certainly is. I get either a “No grad accumulator for a saved leaf!” error or a “One of the differentiated Tensors appears to not have been used in the graph.” error. The former happens when I call detach().requires_grad_() on the state tensor, and the latter happens if I don’t. Here’s the relevant code:

    def forward(self, t_hat: Tensor, y: Tuple[Tensor, Tensor]) -> Tensor:
        x, costate = y

        # Re-dimensionalize
        tof = self.tof
        t = self.t0 + t_hat * tof
        body_positions = [body(t)[0] for body in (self.attractor, self.source, self.target)]

        # The odeint function will call us with torch.no_grad, but we need to compute the
        # gradient of the Hamiltonian wrt the state as part of the basic dynamics
        grad_was_enabled = torch.is_grad_enabled()
        with torch.enable_grad():
            # This is the line of code that changes the type of error message I get
            x = x.detach().requires_grad_()

            # Decode position, velocity, and mass
            rv, m = x[:6], x[6]
            r, v = rv[:3], rv[3:]

            # Compute gravity
            separations = (body_pos - r for body_pos in body_positions)
            gravity = sum(
                sep * mu / sep.norm(dim=0, keepdim=True).maximum(radius) ** 3
                for sep, mu, radius in zip(separations, self.mus, self.radii)
            )
            assert gravity.isfinite().all(), "Got infinite gravity value"

            thrust_mag, thrust_hat = self.get_controls(costate, m)
            thrust = thrust_mag * thrust_hat
            thrust_acc = torch.where(m > 0.0, thrust / m, 0.0)      # Thrust must be zero when we have zero mass

            # Compute time derivative of the state
            x_dot = torch.cat([
                v,                                                  # Change of position = velocity
                gravity + thrust_acc,                               # Change of velocity = acc. due to gravity + acc. due to thrust
                torch.where(m > 0.0, -thrust_mag / self._v_e, 0.0)   # Change of mass = -thrust / exhaust velocity
            ]) * tof
            hamiltonian = torch.sum(costate * x_dot) - thrust_mag.sum()

            # Identity: dλ/dt = -dH/dx.
            lambda_dot = -torch.autograd.grad(hamiltonian, x, create_graph=grad_was_enabled, retain_graph=grad_was_enabled)[0]

        return torch.cat([x_dot, lambda_dot])

I know this is sort of a complicated problem, but any suggestions about how I could solve this would be greatly appreciated, since I am quite stumped.

Your workaround of detaching the tensor and calling requires_grad_() on it might not work properly, and the latter error suggests that the issue might indeed be an unused parameter.
Could you post the shapes for each tensor to make your code snippet executable so that we could take a proper look at it?
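A minimal plain-PyTorch sketch (no torchdiffeq; the cubic "Hamiltonian" and both helper names are hypothetical) of why detach().requires_grad_() can cause trouble: the detached copy is a brand-new leaf, so anything the caller computes from the result has no path back to the caller's tensor. The second helper only detaches when the input carries no autograd history; whether that change alone resolves the odeint_adjoint case is untested.

```python
import torch

def dH_dx_detached(x):
    # Mirrors the original pattern: always make a fresh leaf.
    with torch.enable_grad():
        x = x.detach().requires_grad_()
        h = (x ** 3).sum()  # hypothetical stand-in for the Hamiltonian
        return torch.autograd.grad(h, x, create_graph=True)[0]

def dH_dx_connected(x):
    # Only detach when x has no autograd history worth preserving.
    with torch.enable_grad():
        if not x.requires_grad:
            x = x.detach().requires_grad_()
        h = (x ** 3).sum()
        return torch.autograd.grad(h, x, create_graph=True)[0]

y = torch.tensor([1.0, 2.0], requires_grad=True)

# dH/dx = 3 * y**2 in both cases, but only the connected version stays linked to y.
g_detached = torch.autograd.grad(dH_dx_detached(y).sum(), y, allow_unused=True)[0]
g_connected = torch.autograd.grad(dH_dx_connected(y).sum(), y)[0]
print(g_detached)   # None: the path back to y was cut
print(g_connected)  # tensor([ 6., 12.]), i.e. d/dy of sum(3 * y**2)
```

Under an outer torch.no_grad(), dH_dx_connected still works, because a tensor created there does not require grad and takes the detach branch.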

Thanks for your reply. x and costate are vectors of shape [7], tof and thrust_mag are scalars of shape [1] and thrust_hat and gravity are vectors of shape [3].

Since my initial post I’ve been able to run the forward() method by itself in a torch.no_grad() context in a Jupyter notebook, so it seems like there’s something special about the odeint_adjoint function that is causing the problem. I might have to dive into the torchdiffeq source code to diagnose the problem, but it’s unclear to me how I could be getting different results calling forward() in a no grad context directly vs. calling odeint_adjoint which then calls forward().
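For reference, here is a stripped-down, self-contained version of this kind of dynamics function, using the shapes above and called directly under torch.no_grad() (the situation described as working). Everything specific is hypothetical: a single attractor at the origin, fixed controls in place of get_controls, made-up MU and V_E constants, no time scaling by tof, and a flat [14] output instead of a tuple.

```python
import torch

MU = 1.0   # hypothetical gravitational parameter
V_E = 2.0  # hypothetical exhaust velocity

def dynamics(t, y):
    x, costate = y  # both shape [7]
    with torch.enable_grad():
        x = x.detach().requires_grad_()
        r, v, m = x[:3], x[3:6], x[6]
        gravity = -MU * r / r.norm() ** 3            # single attractor at the origin
        thrust_mag = torch.tensor([0.1])             # hypothetical fixed control, shape [1]
        thrust_hat = torch.tensor([1.0, 0.0, 0.0])   # hypothetical fixed direction, shape [3]
        thrust_acc = torch.where(m > 0.0, thrust_mag * thrust_hat / m, 0.0)
        x_dot = torch.cat([
            v,                                               # position rate
            gravity + thrust_acc,                            # velocity rate
            torch.where(m > 0.0, -thrust_mag / V_E, 0.0),    # mass rate
        ])
        hamiltonian = torch.sum(costate * x_dot) - thrust_mag.sum()
        lambda_dot = -torch.autograd.grad(hamiltonian, x)[0]  # dλ/dt = -dH/dx
    return torch.cat([x_dot, lambda_dot])

x0 = torch.tensor([1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 10.0])
costate0 = torch.randn(7)
with torch.no_grad():
    out = dynamics(torch.tensor(0.0), (x0, costate0))
print(out.shape)  # torch.Size([14])
```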

Thanks for the update. Would I be able to use the posted shapes to create new (random) tensors and run your code snippet to reproduce the issue or is torchdiffeq a needed module to run into the errors?

I do think torchdiffeq is essential for reproducing the error. After a bit of print-debugging I’ve figured out that when I call odeint_adjoint, the state vector x is a leaf node that requires grad, while it is not a leaf node when I directly call forward() in a no_grad() context. Also, if I use the vanilla odeint function from torchdiffeq instead of odeint_adjoint, there is no error. So I think it’s potentially a bug in torchdiffeq; I’ll need to look into it some more.
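A small helper along these lines (the name describe is hypothetical) can make that kind of print-debugging more compact. One caveat: PyTorch reports is_leaf=True for every tensor with requires_grad=False, so requires_grad and grad_fn are the more telling fields when comparing the two call paths.

```python
import torch

def describe(name, t):
    """Summarize a tensor's autograd status for print-debugging."""
    fn = type(t.grad_fn).__name__ if t.grad_fn is not None else None
    return f"{name}: is_leaf={t.is_leaf} requires_grad={t.requires_grad} grad_fn={fn}"

x_plain = torch.randn(7)                  # no autograd history at all
x_leaf = torch.randn(7).requires_grad_()  # leaf that requires grad
print(describe("plain", x_plain))         # is_leaf=True, requires_grad=False
print(describe("leaf", x_leaf))           # is_leaf=True, requires_grad=True, grad_fn=None
print(describe("slice", x_leaf[:6]))      # is_leaf=False: slicing records a grad_fn
```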

So the error message describes, very literally, the situation it found: there is a leaf (i.e. a tensor with no grad_fn) that doesn’t have an accumulator (the node that adds the gradients it receives to .grad). Unfortunately, you cannot get a representation of the accumulator directly from the leaf in Python (interestingly, you can reach it in the graph through the grad_fn.next_functions of a tensor computed from the leaf, but I don’t know of a way to get it from the leaf itself).
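The parenthetical can be illustrated in a few lines: a leaf has no grad_fn, but the first graph node built from it lists the leaf's AccumulateGrad node in its next_functions.

```python
import torch

x = torch.randn(3, requires_grad=True)  # leaf: grad_fn is None
y = x * 2                                # non-leaf: y.grad_fn is a multiplication backward node

acc = y.grad_fn.next_functions[0][0]     # the node that accumulates into x.grad
print(x.grad_fn)                         # None
print(type(acc).__name__)                # AccumulateGrad
print(torch.equal(acc.variable, x))      # True: the node points back at x
```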

This is strange because the things that make leaf tensors (requires_grad_, factory functions with requires_grad=True) should all reliably set up the accumulator as needed, so it looks like something is not setting up the grad_fn properly.
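As a sanity check in plain PyTorch, all three common ways of creating a leaf, including the detach().requires_grad_() pattern from the original post, normally get a working accumulator, which backward() then fills in:

```python
import torch

a = torch.zeros(3, requires_grad=True)              # factory function
b = torch.ones(3).requires_grad_()                  # requires_grad_() on a fresh tensor
c = (torch.ones(3) * 5).detach().requires_grad_()   # detach().requires_grad_() pattern

loss = (a + 2 * b + 3 * c).sum()
loss.backward()

print([t.is_leaf for t in (a, b, c)])  # [True, True, True]
print(a.grad, b.grad, c.grad)          # all ones, all twos, all threes
```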

This essentially means that the autograd setup was messed up when the tensor in question was created. I must admit it fascinates me: I have had it happen when I manipulated the autograd graph manually in C++, but I have never seen it triggered from Python. My guess is that the most likely way to trigger it is doing very funny stuff inside an autograd.Function.

If you have something that reproduces it with code that you’re willing to share, I would be most grateful.

Best regards

Thomas