Having trouble calculating gradients locally inside nn.Module


I am using the neural ODE solvers from torchdiffeq (https://github.com/rtqichen/torchdiffeq) and trying to calculate gradients “locally” inside the module that the solver requires as its dynamics function (ODEfunc). I am not using any optimizer, since there are no trainable parameters.

Here is my simplified code:

import numpy as np
import torch
import torch.nn as nn
# !pip install torchdiffeq # uncomment to install
from torchdiffeq import odeint_adjoint as odeint

def loss_fun(x):
    return x[0]**2 + x[1]**2

class ODEfunc(nn.Module):
    def __init__(self):
        super(ODEfunc, self).__init__()

    def forward(self, t, w):
        # Fresh leaf tensor so the gradient is taken w.r.t. the current state only
        _w = w.clone().detach().requires_grad_(True)
        print(_w)
        loss_val = loss_fun(_w)
        loss_val.backward()  # fails here when called from odeint
        grad = _w.grad
        return grad

w0 = torch.tensor([-0.5, 2.5])
odefun = ODEfunc()
path = odeint(odefun, w0, torch.tensor([0, 1]).float(), rtol=1e-3, atol=1e-3)

I get the following output:

tensor([-0.5000,  2.5000], requires_grad=True)

followed by this error on the loss_val.backward() line:

RuntimeError: element 0 of tensors does not require grad and does not have a grad_fn

However, if I run the same gradient computation outside of ODEfunc:

_w = w0.clone().requires_grad_(True)
print(_w)
loss_val = loss_fun(_w)
print(loss_val)
loss_val.backward()
grad = _w.grad

I get:

tensor([-0.5000,  2.5000], requires_grad=True)
tensor(6.5000, grad_fn=<AddBackward0>)

with no errors, and the gradient values are correct.

The difference is that in the second case my loss has a grad_fn. I don’t understand why wrapping the code in another nn.Module breaks the gradient computation. Any help or workaround would be great!
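One way to check whether nn.Module itself is the culprit is to call the same module twice, once normally and once with grad mode disabled. The sketch below does not use torchdiffeq; it assumes that wrapping the call in torch.no_grad() is a fair stand-in for the context the solver calls ODEfunc in. The try/except only captures the error message for inspection.

```python
import torch
import torch.nn as nn


def loss_fun(x):
    return x[0] ** 2 + x[1] ** 2


class ODEfunc(nn.Module):
    def forward(self, t, w):
        _w = w.clone().detach().requires_grad_(True)
        loss_val = loss_fun(_w)
        loss_val.backward()
        return _w.grad


w0 = torch.tensor([-0.5, 2.5])
t0 = torch.tensor(0.0)
odefun = ODEfunc()

# Direct call: grad mode is on by default, so backward() succeeds.
direct_grad = odefun(t0, w0)
print(direct_grad)

# Same call with grad mode disabled (assumed to mimic the solver's context).
error_message = None
try:
    with torch.no_grad():
        odefun(t0, w0)
except RuntimeError as err:
    error_message = str(err)
print(error_message)
```

Under torch.no_grad(), _w still has requires_grad=True, but the operations in loss_fun are not recorded, so loss_val has no grad_fn and backward() raises the same RuntimeError as in the question.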

It turns out that odeint runs its forward pass under torch.no_grad(). I fixed my problem by using torch.enable_grad() inside forward in ODEfunc.
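A minimal sketch of that fix, following the description above. The torch.no_grad() block at the end is an assumed stand-in for the solver's context so the example runs without torchdiffeq installed.

```python
import torch
import torch.nn as nn


def loss_fun(x):
    return x[0] ** 2 + x[1] ** 2


class ODEfunc(nn.Module):
    def forward(self, t, w):
        # The solver may call forward() under torch.no_grad(), so turn
        # grad mode back on just for the local gradient computation.
        with torch.enable_grad():
            _w = w.clone().detach().requires_grad_(True)
            loss_val = loss_fun(_w)
            loss_val.backward()
        # _w.grad carries no autograd history, so it is safe to return.
        return _w.grad


w0 = torch.tensor([-0.5, 2.5])
odefun = ODEfunc()

# Stand-in for the solver's context: grad mode disabled around the call.
with torch.no_grad():
    grad = odefun(torch.tensor(0.0), w0)
print(grad)  # gradient of x0^2 + x1^2 is 2 * w0, i.e. [-1.0, 5.0]
```

With torchdiffeq installed, this ODEfunc should then work in the original odeint(odefun, w0, torch.tensor([0, 1]).float(), rtol=1e-3, atol=1e-3) call. An alternative to backward() is torch.autograd.grad(loss_val, _w), which returns a tuple containing the gradient instead of storing it in _w.grad; it still needs grad mode to be enabled.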