Error 'Output 0 is independent of input 0' when using jacobian on a function whose output changes with the input in my demo

I am trying to use torch.autograd.functional.jacobian to calculate the gradients of a set of losses with respect to the model parameters, but it returns all zeros. My function definitions are as follows:

def load_weights(model, names, orig_params, new_params, as_params=False):
    param_shapes = [p.shape for p in model.dnn.parameters()]

    start = 0
    for name, p, new_p, shape in zip(names, orig_params, new_params, param_shapes):
        numel = int(torch.prod(torch.tensor(shape)))
        set_attr(model.dnn, name.split("."), torch.nn.Parameter(new_params[start:start + numel].view(shape)))
        start += numel

def func(param_list):
    load_weights(self.model, names, org_param, param_list, 1)
    result = self.eq_cons(param_list)
    result.requires_grad_()
    return result

jac_mtx = torch.autograd.functional.jacobian(func, param_list, strict=1)

The eq_cons function updates the model with the passed parameters, computes a series of losses, and concatenates them into a tensor whose shape matches the data. When I modify param_list, I can see that the return values of both eq_cons and func change.
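(For context, a hypothetical sketch of this kind of constraint function; x_data, y_data, and the residual form are placeholders, not my real eq_cons.)

def eq_cons(self, param_list):
    preds = self.model.dnn(self.x_data)   # forward pass with the updated parameters
    residuals = preds - self.y_data       # one constraint / loss term per data point
    return residuals.flatten()            # concatenated into a single tensor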

However, when I use torch.autograd.functional.jacobian to compute the Jacobian of func with respect to param_list, I find that the resulting Jacobian matrix is all zeros. When I set strict=1, it raises the error:

RuntimeError: Output 0 of the user-provided function is independent of input 0. This is not allowed in strict mode.

To debug this, I checked model.parameters() at the end of the load_weights function and I am sure that the parameters are updated. Besides, I attempted to split the output, call backward in a loop, and compute the gradients separately with the following code. However, it returned the same result for different terms in the output.

for index in range(len(output)): 
    self.model.dnn.zero_grad() 
    item = output[index]
    item.backward()         
    for p in self.model.dnn.parameters():
        param_grad = p.grad.detach().data
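
For reference, the same per-term check can be sketched with torch.autograd.grad instead of backward(); retain_graph=True lets the graph be reused across terms, and allow_unused=True returns None for any parameter the output is not actually connected to:

params = list(self.model.dnn.parameters())
for item in output:
    grads = torch.autograd.grad(item, params, retain_graph=True, allow_unused=True)
    print([g is None for g in grads])  # True entries mean no graph connection to that parameter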

I am confused by this error because the output of func does change with the input, and I have no idea how to fix it. Why does this error occur, and how can I modify my code to compute the Jacobian correctly?

Hi Catfish!

This looks very suspicious. The fact that you call result.requires_grad_() suggests that
self.eq_cons() doesn’t properly build and / or maintain the computation graph (the graph that
autograd uses to compute the jacobian) and that you’re trying to fix the issue after the fact. But
once the computation graph is broken, .requires_grad_() can’t fix it.

Make sure that you can backward successfully through func() or otherwise check the integrity
of the computation graph that gets built when you run it.
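
A quick sanity check (a minimal sketch, using the names from your post) is whether the result of func() carries a grad_fn at all:

result = func (param_list)
print (result.grad_fn)        # None means the result is not connected to any graph
print (result.requires_grad)  # can be True even when grad_fn is None (leaf tensor)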

Consider this illustrative script:

import torch
print (torch.__version__)

def f (x):
    xsq = x * x
    return  xsq

def f_break (x):
    with torch.no_grad():   # breaks the computation graph
        xsq = x * x
    return  xsq

def f_break_req (x):
    with torch.no_grad():
        xsq = x * x
    xsq.requires_grad_()    # doesn't fix the problem -- graph still broken
    return  xsq

x = torch.tensor ([2., 3.])

jac = torch.autograd.functional.jacobian (f, x)
print ('jac = ...')
print (jac)

jac_break = torch.autograd.functional.jacobian (f_break, x)
print ('jac_break = ...')
print (jac_break)

jac_break_req = torch.autograd.functional.jacobian (f_break_req, x)
print ('jac_break_req = ...')
print (jac_break_req)

And its output:

2.7.0+cu128
jac = ...
tensor([[4., 0.],
        [0., 6.]])
jac_break = ...
tensor([[0., 0.],
        [0., 0.]])
jac_break_req = ...
tensor([[0., 0.],
        [0., 0.]])
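
As an aside, a minimal sketch of the same "jacobian with respect to model parameters" pattern using torch.func (available in recent pytorch versions), which avoids rewriting module attributes by hand; the Linear model and random input here are just placeholders:

import torch
from torch.func import functional_call, jacrev

model = torch.nn.Linear (2, 3)
x = torch.randn (5, 2)
params = dict (model.named_parameters())

def g (params):
    return functional_call (model, params, (x,))   # put your losses / constraints here

jac = jacrev (g) (params)   # dict mapping parameter name to its jacobian block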

Best.

K. Frank

Thank you for your suggestion! I checked my eq_cons and realized that I had used detach(), which broke the computation graph. I deleted it and now get the correct result.
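
For anyone who lands here with the same symptom, a minimal illustration in the style of the script above: a stray detach() breaks the graph in exactly the same way as torch.no_grad().

import torch

def f_detach(x):
    xsq = (x * x).detach()   # detach() also breaks the computation graph
    return xsq

x = torch.tensor([2., 3.])
print(torch.autograd.functional.jacobian(f_detach, x))
# tensor([[0., 0.],
#         [0., 0.]])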