Following the tutorial "Jacobians, Hessians, hvp, vhp, and more: composing functorch transforms" (functorch nightly documentation),
I would like to implement Hessian-vector products using the torch.func module.
Python: 3.10.2
PyTorch: 2.0.1
Here is a minimal working example of what I would like to do:
import torch
# Problem size: (input features, output features) of the linear model,
# and the number of random samples.
dimensions = (3, 1)
num_data = 10

# Random regression data: one row per sample.
x = torch.rand(num_data, dimensions[0])
y = torch.rand(num_data, dimensions[1])

# Model whose parameters the loss is differentiated with respect to.
model = torch.nn.Linear(*dimensions)
params = dict(model.named_parameters())

# Random direction with the same structure (names and shapes) as the
# trainable parameters; this is the "vector" of the Hessian-vector product.
vec = {
    name: torch.rand(*p.shape)
    for name, p in params.items()
    if p.requires_grad
}
def model_func(new_params):
    """Return the MSE loss of ``model`` on ``(x, y)`` using ``new_params``.

    The module-level ``model`` is called functionally, so its own parameters
    are replaced by ``new_params`` for this call. That makes the loss a
    function of the parameter dict, which is what ``torch.func`` transforms
    differentiate.

    Args:
        new_params: Mapping from parameter name to tensor. ``strict=True``
            is passed to ``functional_call``, so the keys are expected to
            match the model's own parameter names.

    Returns:
        The mean-squared-error loss between the model output on ``x`` and ``y``.
    """
    outputs = torch.func.functional_call(model, new_params, (x,), strict=True)
    return torch.nn.functional.mse_loss(outputs, y)
def hvp_revrev(f, primals, tangents):
    """Compute a Hessian-vector product of ``f`` by reverse-over-reverse AD.

    The gradient function ``torch.func.grad(f)`` is itself differentiated
    with ``torch.func.vjp``; pulling ``tangents`` back through it gives the
    product of the Hessian of ``f`` at ``primals`` with ``tangents``.

    Args:
        f: Scalar-valued function to differentiate.
        primals: Tuple of inputs at which the Hessian is evaluated; it is
            unpacked into the call to ``torch.func.vjp``.
        tangents: Tuple of vectors to multiply the Hessian with; it is
            unpacked into the call to the VJP function.

    Returns:
        The result of the VJP function: a tuple with one entry per primal,
        each structured like the corresponding primal.
    """
    # The first return value (the gradient at ``primals``) is not needed.
    _, vjp_fn = torch.func.vjp(torch.func.grad(f), *primals)
    return vjp_fn(*tangents)
# Both approaches differentiate the same loss at the same point, in the
# same direction.
primals = (params,)
tangents = (vec,)

# Reverse-over-reverse (vjp of grad): this works.
hess_rev_vp = hvp_revrev(model_func, primals, tangents)[0]

# Forward-over-reverse (jvp of grad): on PyTorch 2.0.1 this raises
# "RuntimeError: ZeroTensors are immutable."
grad_fn = torch.func.grad(model_func)
hess_vp = torch.func.jvp(grad_fn, primals, tangents)
While the reverse-over-reverse mode works well, the forward-over-reverse mode raises the following error:
Traceback (most recent call last):
File "test_hvp.py", line 26, in <module>
hess_vp = torch.func.jvp(torch.func.grad(model_func), (params, ), (vec,))
File "/opt/homebrew/Caskroom/mambaforge/base/envs/pydvl/lib/python3.10/site-packages/torch/_functorch/eager_transforms.py", line 916, in jvp
return _jvp_with_argnums(func, primals, tangents, argnums=None, strict=strict, has_aux=has_aux)
File "/opt/homebrew/Caskroom/mambaforge/base/envs/pydvl/lib/python3.10/site-packages/torch/_functorch/vmap.py", line 39, in fn
return f(*args, **kwargs)
File "/opt/homebrew/Caskroom/mambaforge/base/envs/pydvl/lib/python3.10/site-packages/torch/_functorch/eager_transforms.py", line 965, in _jvp_with_argnums
result_duals = func(*duals)
File "/opt/homebrew/Caskroom/mambaforge/base/envs/pydvl/lib/python3.10/site-packages/torch/_functorch/eager_transforms.py", line 1380, in wrapper
results = grad_and_value(func, argnums, has_aux=has_aux)(*args, **kwargs)
File "/opt/homebrew/Caskroom/mambaforge/base/envs/pydvl/lib/python3.10/site-packages/torch/_functorch/vmap.py", line 39, in fn
return f(*args, **kwargs)
File "/opt/homebrew/Caskroom/mambaforge/base/envs/pydvl/lib/python3.10/site-packages/torch/_functorch/eager_transforms.py", line 1267, in wrapper
flat_grad_input = _autograd_grad(flat_outputs, flat_diff_args, create_graph=True)
File "/opt/homebrew/Caskroom/mambaforge/base/envs/pydvl/lib/python3.10/site-packages/torch/_functorch/eager_transforms.py", line 113, in _autograd_grad
grad_inputs = torch.autograd.grad(diff_outputs, inputs, grad_outputs,
File "/opt/homebrew/Caskroom/mambaforge/base/envs/pydvl/lib/python3.10/site-packages/torch/autograd/__init__.py", line 303, in grad
return Variable._execution_engine.run_backward( # Calls into the C++ engine to run the backward pass
RuntimeError: ZeroTensors are immutable. Please use the materialized zero tensor obtained using .clone() if you want a mutable tensor.
Is this expected? If so, could you please explain this message to me? Thank you :)