Torch 1.12 introduced a wonderful feature: torch.nn.utils.stateless.functional_call(), whose syntax is as follows: out = functional_call(model, params, inp)

Namely, it allows evaluating the NN with arbitrary parameters. However, how can I calculate the gradient of the output with respect to the evaluated parameters? It seems that it doesn’t save anything to allow backward propagation. Is there a way to enable this, or is this not possible with functional_call, so that I need to go back to the traditional way of just creating one model per set of parameters I want to evaluate?

To be a little more precise, let me add some code:

params = [params1, params2, params3]
inputs = [inp1, inp2, inp3]
grads = []
for param, inp in zip(params, inputs):
    out = torch.nn.utils.stateless.functional_call(model, param, inp)
    grad = torch.autograd.grad(out, param)
    grads.append(grad)

Clearly, that won’t work, but I hope that illustrates what I am trying to do.
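For concreteness, here is a runnable sketch of that loop under some assumptions of mine: each entry of params is a dict mapping parameter names to tensors that require grad (autograd.grad needs that), the output is reduced to a scalar with mean(), and the parameter sets are just random placeholders.

```python
import torch
from torch.nn.utils.stateless import functional_call  # torch >= 1.12

model = torch.nn.Linear(10, 10)

# Three hypothetical parameter sets; the tensors must require grad
# so that autograd can differentiate with respect to them.
params = [
    {"weight": torch.randn(10, 10, requires_grad=True),
     "bias": torch.randn(10, requires_grad=True)}
    for _ in range(3)
]
inputs = [torch.randn(1, 10) for _ in range(3)]

grads = []
for param, inp in zip(params, inputs):
    out = functional_call(model, param, inp)
    # autograd.grad wants a scalar output and a sequence of tensors
    grad = torch.autograd.grad(out.mean(), tuple(param.values()))
    grads.append(grad)
```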

Yes, that works. I found that the problem with my code was that I was doing the following:

model = nn.Linear(10, 10)
param = dict(model.state_dict())
inp = torch.rand(1, 10)
out = torch.nn.utils.stateless.functional_call(model, param, inp)

Then when I try to compute the gradient, I get the error: RuntimeError: element 0 of tensors does not require grad
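That error makes sense, since state_dict() returns tensors that are detached from autograd. A minimal sketch of a fix, assuming we simply rebuild the dict with copies that require grad:

```python
import torch
from torch.nn.utils.stateless import functional_call

model = torch.nn.Linear(10, 10)

# state_dict() tensors are detached, so clone them and turn
# requires_grad back on before passing them to functional_call.
param = {name: tensor.detach().clone().requires_grad_(True)
         for name, tensor in model.state_dict().items()}

inp = torch.rand(1, 10)
out = functional_call(model, param, inp)
grad = torch.autograd.grad(out.mean(), param["weight"])
```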

With that solved, I have another question. Let’s say the model has a bias, then it looks like this:

model = nn.Linear(10, 10)
param = {"weight": nn.Parameter(torch.ones(10, 10)),
         "bias": nn.Parameter(torch.ones(10))}
inp = torch.randn(1, 10)
out = torch.nn.utils.stateless.functional_call(model, param, inp)
grad = torch.autograd.grad(out.mean(), param['weight'])
print(grad)

grad will only contain the gradient with respect to the weight, not the bias. What argument should I pass, instead of param['weight'], to compute the gradient with respect to all parameters? I tried param.items(), but that didn’t work.
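One option that may work here (a sketch, not an authoritative answer): torch.autograd.grad accepts a sequence of tensors as its second argument, so passing the dict's values rather than its (name, tensor) items should yield one gradient per parameter:

```python
import torch
import torch.nn as nn
from torch.nn.utils.stateless import functional_call

model = nn.Linear(10, 10)
param = {"weight": nn.Parameter(torch.ones(10, 10)),
         "bias": nn.Parameter(torch.ones(10))}
inp = torch.randn(1, 10)
out = functional_call(model, param, inp)

# grad takes a sequence of tensors; param.items() fails because it
# yields (name, tensor) pairs, while param.values() yields tensors.
grads = torch.autograd.grad(out.mean(), tuple(param.values()))
```

The result is a tuple of gradients in the same order as the dict's values, so it can be zipped back with param.keys() if named gradients are needed.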