Taking the derivative of the output of a Neural Network in the forward method w.r.t. input

Hello. I am trying to create an ODE solver. For this purpose, I wish to take the derivative of the output of the network w.r.t. its input in the forward method. Here is my code; note that t and N_0 are both of size (batch_size, 1), and I wish to have N_1 be of this size as well.

    def forward(self, t):
        N_0 = self.net(t)
        N_0.requires_grad = True
        ones = torch.ones_like(N_0)
        N_1 = torch.autograd.grad(N_0, t, ones, retain_graph=True)[0]
        return t, N_0, N_1  

but I get the error

RuntimeError: you can only change requires_grad flags of leaf variables.

in the line N_0.requires_grad = True. This makes sense to me, but my question is: How do I fix this? I have tried the following:

    def forward(self, t):
        t.requires_grad = True
        N_0 = self.net(t)
        N = N_0.detach()
        N.requires_grad = True
        ones = torch.ones_like(N)
        N_1 = torch.autograd.grad(N, t, ones, retain_graph=True)[0]
        return t, N_0, N_1  

but I get the error

RuntimeError: One of the differentiated Tensors appears to not have been used in the graph. Set allow_unused=True if this is the desired behavior.

Note that I will be using t, N_0 and N_1 in my loss function, which will be used to learn the weights (i.e. I wish to take the derivative N_1 of N_0 w.r.t. t while keeping the weights of the network fixed). Any help would be appreciated! Thanks!
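For reference, here is a small sketch of why the second attempt fails. It uses a hypothetical three-layer network in place of self.net. detach() returns a new leaf tensor that shares data with N_0 but has no history, so N is no longer connected to t in the autograd graph.

```python
import torch

# Hypothetical stand-in for self.net: maps (batch, 1) -> (batch, 1)
net = torch.nn.Sequential(torch.nn.Linear(1, 8), torch.nn.Tanh(), torch.nn.Linear(8, 1))

t = torch.rand(4, 1, requires_grad=True)
N_0 = net(t)
print(N_0.is_leaf, N_0.grad_fn is not None)  # False True: N_0 was computed from t

# detach() gives a fresh leaf with no grad_fn; setting requires_grad on it is allowed,
# but nothing links N back to t anymore.
N = N_0.detach().requires_grad_()
print(N.is_leaf, N.grad_fn)  # True None

# With allow_unused=True autograd reports the missing connection as None
# instead of raising the "appears to not have been used in the graph" RuntimeError.
(dN_dt,) = torch.autograd.grad(N, t, torch.ones_like(N), allow_unused=True)
print(dN_dt)  # None
```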

I believe you need to specify that t requires grad rather than N_0, which is not a leaf variable, as you have observed:

    import torch

    c = torch.nn.Conv2d(1, 1, 1, 1)
    i = torch.randn(1, 1, 1, 1, requires_grad=True)
    o = c(i).sum()
    grad = torch.autograd.grad(o, i)
    print(grad)
    # (tensor([[[[-0.2108]]]]),)
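Applied to the setting in the question, a minimal sketch might look like the following. The module name ODENet, its layer sizes and the example ODE dN/dt = -N in the loss are hypothetical stand-ins. Note the create_graph=True flag: without it, N_1 is returned as a tensor that does not require grad, so a loss term built from N_1 would not send gradients back to the weights.

```python
import torch

# Hypothetical network standing in for self.net (one input feature, one output feature)
class ODENet(torch.nn.Module):
    def __init__(self, hidden=16):
        super().__init__()
        self.net = torch.nn.Sequential(
            torch.nn.Linear(1, hidden), torch.nn.Tanh(), torch.nn.Linear(hidden, 1)
        )

    def forward(self, t):
        # Mark the input (a leaf) as requiring grad, so N_0 gets a grad_fn
        t = t.requires_grad_(True)
        N_0 = self.net(t)
        # grad_outputs=ones sums over the batch; since row i of N_0 depends only on
        # row i of t, the result is dN_0/dt per sample, with shape (batch_size, 1).
        # create_graph=True keeps N_1 differentiable so a loss on N_1 can train the weights.
        (N_1,) = torch.autograd.grad(N_0, t, torch.ones_like(N_0), create_graph=True)
        return t, N_0, N_1


model = ODENet()
t = torch.rand(8, 1)  # shape (batch_size, 1)
t, N_0, N_1 = model(t)

# Example residual loss for the hypothetical ODE dN/dt = -N
loss = ((N_1 + N_0) ** 2).mean()
loss.backward()  # gradients reach the weights through both N_0 and N_1
```

The per-sample interpretation of the ones-weighted gradient only holds because samples do not interact inside the network. A layer that mixes the batch (e.g. BatchNorm in training mode) would break it.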

Thanks! As you suggested, I made the following change:

    def forward(self, t):
        t.requires_grad = True
        N_0 = self.net(t)
        ones = torch.ones_like(N_0)
        N_1 = torch.autograd.grad(N_0, t, ones, retain_graph=True)[0]
        return t, N_0, N_1    

but now I get the error:

RuntimeError: element 0 of tensors does not require grad and does not have a grad_fn

in the line

    N_1 = torch.autograd.grad(N_0, t, ones, retain_graph=True)[0]

What is being done in self.net? I cannot reproduce the error with the code as written.

I may not be able to reveal the entire code, but self.net consists of the layers of the network.

I have a feeling that one particular section of my code could be causing this error (I may be wrong, though). In one of the layers in the network, I am initialising the weights of the linear part (before applying the activations) as follows:

        with torch.no_grad():
            if self.is_first:
                self.linear.weight.uniform_(-1 / self.in_features, 
                                             1 / self.in_features)

Could using torch.no_grad() be the issue?
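The torch.no_grad() block in the snippet above should not cause this by itself, since it only wraps the one-time weight initialisation. A more likely trigger (an assumption, since the full code is not shown) is that forward itself runs while gradient tracking is disabled, for example inside an evaluation loop wrapped in torch.no_grad(). The sketch below uses a hypothetical SIREN-style layer (the sine activation is assumed) to show both cases, plus torch.enable_grad() as one local workaround.

```python
import torch

class Layer(torch.nn.Module):
    # Hypothetical layer mimicking the snippet: no_grad is used only for weight init
    def __init__(self, in_features, out_features, is_first=False):
        super().__init__()
        self.in_features = in_features
        self.is_first = is_first
        self.linear = torch.nn.Linear(in_features, out_features)
        with torch.no_grad():
            if self.is_first:
                self.linear.weight.uniform_(-1 / self.in_features, 1 / self.in_features)

    def forward(self, x):
        return torch.sin(self.linear(x))


net = Layer(1, 1, is_first=True)
t = torch.rand(4, 1, requires_grad=True)

# Fine: no_grad during __init__ does not affect later forward passes
N_0 = net(t)
print(N_0.grad_fn is not None)  # True

# Problem: a forward pass run under no_grad produces an output without grad_fn,
# and torch.autograd.grad on it raises "element 0 of tensors does not require grad ..."
with torch.no_grad():
    N_0 = net(t)
print(N_0.requires_grad)  # False

# One workaround: re-enable tracking locally where the derivative is needed
with torch.no_grad():
    with torch.enable_grad():
        N_0 = net(t)
        (N_1,) = torch.autograd.grad(N_0, t, torch.ones_like(N_0))
```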

Hi @Yash_Ramani

You could also have a look at torch.func, docs here: torch.func — PyTorch 2.0 documentation

That would allow you to compute the output and gradients at the same time, but its syntax is a bit different. Here's an example:

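    import torch
    from torch.func import functional_call, grad, vmap

    # Model, args, kwargs and inputs stand in for your own network and data
    net = Model(*args, **kwargs)
    params = dict(net.named_parameters())

    def compute_output(params, inputs):
        outputs = functional_call(net, params, (inputs,))
        # grad needs a scalar (0-dim) output, so squeeze the single output feature;
        # the unsqueezed outputs are returned as the auxiliary value
        return outputs.squeeze(-1), outputs

    def grad_and_value(params, inputs):
        # argnums=1: differentiate w.r.t. inputs rather than params
        gradients, values = grad(compute_output, argnums=1, has_aux=True)(params, inputs)  # use jacrev/jacfwd for multiple outputs
        return gradients, values

    # vectorize the gradient/output function: params shared, inputs batched along dim 0
    gradients, outputs = vmap(grad_and_value, in_dims=(None, 0))(params, inputs)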
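To check the approach end to end, here is a sketch with a hypothetical small MLP in place of Model. It compares the torch.func result against the plain torch.autograd.grad approach used earlier in the thread. For a network with a single output feature and no cross-sample layers, both should agree up to floating-point error.

```python
import torch
from torch.func import functional_call, grad, vmap

# Hypothetical stand-in for Model(*args, **kwargs): one input feature, one output feature
net = torch.nn.Sequential(torch.nn.Linear(1, 16), torch.nn.Tanh(), torch.nn.Linear(16, 1))
params = dict(net.named_parameters())

def compute_output(params, inputs):
    # Inside vmap, inputs is a single sample of shape (1,)
    outputs = functional_call(net, params, (inputs,))
    return outputs.squeeze(-1), outputs  # scalar for grad, full output as aux

def grad_and_value(params, inputs):
    return grad(compute_output, argnums=1, has_aux=True)(params, inputs)

inputs = torch.rand(8, 1)  # (batch_size, 1)
gradients, outputs = vmap(grad_and_value, in_dims=(None, 0))(params, inputs)

# Cross-check against torch.autograd.grad with a ones-weighted output
t = inputs.clone().requires_grad_(True)
N_0 = net(t)
(N_1,) = torch.autograd.grad(N_0, t, torch.ones_like(N_0))
print(torch.allclose(gradients, N_1, atol=1e-6), torch.allclose(outputs, N_0, atol=1e-6))
```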

Hi @AlphaBetaGamma96!

Thank you for your answer. I think I have understood the code you provided and have used it in my context as follows:

    def compute_output(self, inputs):
      outputs = torch.func.functional_call(self.net, self.params, inputs)
      return outputs, outputs 

    def grad_and_value(self, inputs):
      gradients, values = torch.func.jacrev(self.compute_output, has_aux=True)(inputs) #use jacrev/jacfwd for multiple outputs
      return gradients, values
    
    def forward(self, t):
        # t.requires_grad = True
        # N_0 = self.net(t)
        # ones = torch.ones_like(N_0)
        # N_1 = torch.autograd.grad(N_0, t, ones, retain_graph=True)[0]
        gradients, outputs = torch.vmap(self.grad_and_value, in_dims=(None, 0))(t)
        return t, outputs, gradients 

where

    self.params = dict(self.net.named_parameters())

However, I get the error

ValueError: vmap(grad_and_value, in_dims=(None, 0), ...)(<inputs>): in_dims is not compatible with the structure of `inputs`. in_dims has structure TreeSpec(tuple, None, [*, *]) but inputs has structure TreeSpec(tuple, None, [*]).

in the line

    gradients, outputs = torch.vmap(self.grad_and_value, in_dims=(None, 0))(t)

I also get the same error when using torch.func.grad.

Make sure to have a minimal reproducible example, and check that you have an nn.Module object.
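The ValueError itself comes from the vmap call. in_dims=(None, 0) has two entries, but self.grad_and_value is a bound method called with a single positional argument t, because self.params is read from the module rather than passed in. Here is a minimal sketch of a self-contained nn.Module along those lines that passes in_dims=0 instead. The network, its sizes and the example ODE dN/dt = -N in the loss are hypothetical.

```python
import torch
from torch.func import functional_call, jacrev

class ODENet(torch.nn.Module):
    # Hypothetical module mirroring the structure in the question
    def __init__(self, hidden=16):
        super().__init__()
        self.net = torch.nn.Sequential(
            torch.nn.Linear(1, hidden), torch.nn.Tanh(), torch.nn.Linear(hidden, 1)
        )
        self.params = dict(self.net.named_parameters())

    def compute_output(self, inputs):
        outputs = functional_call(self.net, self.params, (inputs,))
        return outputs, outputs

    def grad_and_value(self, inputs):
        # Per sample: inputs (1,), outputs (1,), Jacobian (1, 1)
        gradients, values = jacrev(self.compute_output, has_aux=True)(inputs)
        return gradients.squeeze(-1), values

    def forward(self, t):
        # Only one positional argument (t) is passed, so in_dims has a single entry
        gradients, outputs = torch.vmap(self.grad_and_value, in_dims=0)(t)
        return t, outputs, gradients


model = ODENet()
t = torch.rand(8, 1)  # (batch_size, 1)
t, N_0, N_1 = model(t)
loss = ((N_1 + N_0) ** 2).mean()  # example residual for the hypothetical ODE dN/dt = -N
loss.backward()
```

The squeeze turns each per-sample (1, 1) Jacobian into shape (1,), so N_1 comes out as (batch_size, 1), matching N_0.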