I'm currently writing a library on top of torchdiffeq that optimizes spacecraft trajectories with autograd. I've written a function that computes the time derivative of the spacecraft state, and I pass it to the odeint_adjoint() function from torchdiffeq, which numerically integrates the ODE and uses the adjoint method to find the gradient of the loss with respect to the initial conditions during the backward pass.
The problem I'm running into is that part of my dynamics function needs to use torch.autograd.grad to compute a gradient (specifically, the gradient of the Hamiltonian with respect to the spacecraft state). In other words, the forward pass of my model itself involves autograd. The odeint_adjoint() implementation calls my function in a torch.no_grad() context, so I'm using torch.enable_grad() in my forward pass to compute the gradient. But no matter what I try, autograd seems to think that the state tensor is not used to compute the Hamiltonian, even though it certainly is. I get either a "No grad accumulator for a saved leaf!" error or a "One of the differentiated Tensors appears to not have been used in the graph." error. The former happens when I call detach().requires_grad_() on the state tensor, and the latter happens if I don't. Here's the relevant code:
def forward(self, t_hat: Tensor, y: Tuple[Tensor, Tensor]) -> Tensor:
    x, costate = y

    # Re-dimensionalize
    tof = self.tof
    t = self.t0 + t_hat * tof
    body_positions = [body(t)[0] for body in (self.attractor, self.source, self.target)]

    # The odeint function will call us with torch.no_grad, but we need to compute the
    # gradient of the Hamiltonian wrt the state as part of the basic dynamics
    grad_was_enabled = torch.is_grad_enabled()
    with torch.enable_grad():
        # This is the line of code that changes the type of error message I get
        x = x.detach().requires_grad_()

        # Decode position, velocity, and mass
        rv, m = x[:6], x[6]
        r, v = rv[:3], rv[3:]

        # Compute gravity
        separations = (body_pos - r for body_pos in body_positions)
        gravity = sum(
            sep * mu / sep.norm(dim=0, keepdim=True).maximum(radius) ** 3
            for sep, mu, radius in zip(separations, self.mus, self.radii)
        )
        assert gravity.isfinite().all(), "Got infinite gravity value"

        thrust_mag, thrust_hat = self.get_controls(costate, m)
        thrust = thrust_mag * thrust_hat
        thrust_acc = torch.where(m > 0.0, thrust / m, 0.0)  # Thrust must be zero when we have zero mass

        # Compute time derivative of the state
        x_dot = torch.cat([
            v,                      # Change of position = velocity
            gravity + thrust_acc,   # Change of velocity = acc. due to gravity + acc. due to thrust
            torch.where(m > 0.0, -thrust_mag / self._v_e, 0.0)  # Change of mass = -thrust / exhaust velocity
        ]) * tof

        hamiltonian = torch.sum(costate * x_dot) - thrust_mag.sum()

        # Identity: dλ/dt = -dH/dx.
        lambda_dot = -torch.autograd.grad(hamiltonian, x, create_graph=grad_was_enabled, retain_graph=grad_was_enabled)[0]

    return torch.cat([x_dot, lambda_dot])
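For what it's worth, I can reproduce the second error message in plain PyTorch, with no torchdiffeq involved, whenever an op executed under torch.no_grad() severs the graph between the output and x (toy tensors here, just to illustrate what I think is happening):

```python
import torch

x = torch.ones(3, requires_grad=True)
with torch.no_grad():
    y = (x ** 2).sum()  # computed with grad disabled: the link back to x is lost here

with torch.enable_grad():
    w = torch.ones((), requires_grad=True)
    z = y * w  # z has a graph, but it only reaches back to w, not x

try:
    torch.autograd.grad(z, x)
except RuntimeError as e:
    print(e)  # "One of the differentiated Tensors appears to not have been used in the graph..."
```

So my guess is that, without the detach() call, the state tensor handed to me by the integrator is similarly disconnected from the Hamiltonian's graph, but I don't see how to restore the connection.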
I know this is sort of a complicated problem, but any suggestions about how I could solve this would be greatly appreciated, since I am quite stumped.
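For comparison, here's a stripped-down version of the enable_grad/detach pattern with a toy quadratic Hamiltonian standing in for my real dynamics. On its own, outside of odeint_adjoint, it runs without error, which is part of why I'm confused:

```python
import torch

def toy_dynamics(x: torch.Tensor) -> torch.Tensor:
    # Mimics my forward(): called under torch.no_grad() by the integrator,
    # re-enables grad locally, and differentiates a toy Hamiltonian wrt the state.
    with torch.enable_grad():
        x = x.detach().requires_grad_()
        hamiltonian = (x ** 2).sum()  # stand-in for the real Hamiltonian
        (grad_h,) = torch.autograd.grad(hamiltonian, x)
    return -grad_h

with torch.no_grad():
    out = toy_dynamics(torch.tensor([1.0, 2.0, 3.0]))
print(out)  # tensor([-2., -4., -6.])
```

So the pattern itself seems fine; the errors only show up once odeint_adjoint is in the loop.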