How to enable grad calculation when using stateless.functional_call

Torch 1.12 introduced a wonderful feature: torch.nn.utils.stateless.functional_call(), whose syntax is as follows:
out = functional_call(model, params, inp)

Namely, it allows evaluating the NN with arbitrary parameters. However, how can I calculate the gradient of the output with respect to the parameters used in the call? It seems that nothing is saved to allow backward propagation. Is there a way to enable this, or is this not possible with functional_call, so that I need to go back to the traditional way of creating a separate model for each set of parameters I want to evaluate?

To be a little more precise, let me add some code:

params = [params1, params2, params3]
inputs = [inp1, inp2, inp3]

grads = []

for param, inp in zip(params, inputs):
    out = torch.nn.utils.stateless.functional_call(model, param, inp)
    grad = torch.autograd.grad(out, param)
    grads.append(grad)

Clearly, that won’t work, but I hope that illustrates what I am trying to do.

I’m not sure what the issue is, but it seems to work for me:

model = nn.Linear(10, 10, bias=False)

param = {"weight": nn.Parameter(torch.ones(10, 10))}
inp = torch.randn(1, 10)

out = torch.nn.utils.stateless.functional_call(model, param, inp)
grad = torch.autograd.grad(out.mean(), param['weight'])
print(grad)

Yes, that works. I found that the problem with my code was that I was doing the following:

model = nn.Linear(10, 10)

param = dict(model.state_dict())
inp = torch.rand(1, 10)

out = torch.nn.utils.stateless.functional_call(model, param, inp)

Then when I try to evaluate the grad I get the error: RuntimeError: element 0 of tensors does not require grad
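That makes sense in hindsight: the tensors returned by state_dict() are detached from the autograd graph and have requires_grad=False, so functional_call records nothing to backpropagate through. A minimal sketch of one way around it (my own variable names, building the parameter dict from fresh leaf tensors that do require grad) would be:

# Build a parameter dict whose values require grad,
# instead of using the detached tensors from state_dict().
param = {name: p.detach().clone().requires_grad_(True)
         for name, p in model.named_parameters()}

out = torch.nn.utils.stateless.functional_call(model, param, inp)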

With that solved, I have another question. Let's say the model has a bias; then the code looks like this:

model = nn.Linear(10, 10)

param = {"weight": nn.Parameter(torch.ones(10, 10)),
         "bias": nn.Parameter(torch.ones(10))}

inp = torch.randn(1, 10)

out = torch.nn.utils.stateless.functional_call(model, param, inp)
grad = torch.autograd.grad(out.mean(), param['weight'])
print(grad)

This only computes the gradient with respect to the weight and not the bias. What argument should I pass, instead of param['weight'], to compute the gradient with respect to all parameters? I tried param.items(), but that didn't work.

You can pass a sequence of tensors to autograd.grad:

grad = torch.autograd.grad(out.mean(), [param['weight'], param['bias']])
print(grad)
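
If you don't want to list the parameters by hand, passing the dict's values should also work, since autograd.grad accepts any sequence of tensors (this is just the same call written more generically):

# Differentiate with respect to every tensor in the parameter dict.
grad = torch.autograd.grad(out.mean(), list(param.values()))
print(grad)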