Hi,
I am using torchdiffeq's (https://github.com/rtqichen/torchdiffeq) neural ODE solvers and trying to calculate gradients “locally” inside the module required by the method (ODEfunc). I am not using any optimizer since I have no trainable parameters.
Here is my simplified code:
import numpy as np
import torch
import torch.nn as nn
# !pip install torchdiffeq # uncomment to install
from torchdiffeq import odeint_adjoint as odeint

def loss_fun(x):
    return x[0]**2 + x[1]**2

class ODEfunc(nn.Module):
    def __init__(self):
        super(ODEfunc, self).__init__()

    def forward(self, t, w):
        # fresh leaf tensor so the gradient can be taken with respect to it
        _w = w.clone().detach().requires_grad_(True)
        print(_w)
        loss_val = loss_fun(_w)
        print(loss_val)
        loss_val.backward()  # the RuntimeError is raised here
        grad = _w.grad
        return grad

w0 = torch.tensor([-0.5, 2.5])
odefun = ODEfunc()
path = odeint(odefun, w0, torch.tensor([0, 1]).float(), rtol=1e-3, atol=1e-3)
I get the following output:
tensor([-0.5000, 2.5000], requires_grad=True)
tensor(6.5000)
and the following error on the loss_val.backward() line:
RuntimeError: element 0 of tensors does not require grad and does not have a grad_fn
However, if I run the same gradient code outside of ODEfunc:
_w = w0.clone().requires_grad_(True)
print(_w)
loss_val = loss_fun(_w)
print(loss_val)
loss_val.backward()
grad = _w.grad
I get:
tensor([-0.5000, 2.5000], requires_grad=True)
tensor(6.5000, grad_fn=<AddBackward0>)
with no errors and the correct gradient values.
The difference is that in the second case the loss value has a grad_fn, while inside ODEfunc it does not. I don’t understand why wrapping the code in another nn.Module breaks the calculation of the gradients. Any help or workaround would be great!
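One possible explanation, which I have not checked against the torchdiffeq source, is that odeint_adjoint calls forward() with autograd disabled (as under torch.no_grad()), so loss_val is never attached to a graph. The sketch below reproduces the symptom with plain PyTorch under that assumption and shows torch.enable_grad() as a candidate workaround. The helper name local_grad is hypothetical and used only for illustration.

```python
import torch


def loss_fun(x):
    return x[0] ** 2 + x[1] ** 2


def local_grad(w):
    """Gradient of loss_fun at w, even if the caller has disabled autograd."""
    # enable_grad() re-enables graph recording inside an enclosing no_grad()
    with torch.enable_grad():
        _w = w.clone().detach().requires_grad_(True)
        loss_val = loss_fun(_w)
        loss_val.backward()
    return _w.grad


w0 = torch.tensor([-0.5, 2.5])

# Assumption: this mimics how the adjoint solver invokes forward().
with torch.no_grad():
    _w = w0.clone().detach().requires_grad_(True)
    untracked = loss_fun(_w)              # no grad_fn, as in the failing run
    has_graph = untracked.requires_grad   # False: backward() would raise here
    grad = local_grad(w0)                 # works despite the outer no_grad()

print(has_graph)
print(grad)
```

If the assumption holds, wrapping the body of ODEfunc.forward in `with torch.enable_grad():` should remove the error. The returned gradient equals 2 * w0 for this loss and is detached from any outer graph, which seems acceptable here because there are no trainable parameters.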