Hessian-vector products with torch.func

Following the functorch tutorial "Jacobians, Hessians, hvp, vhp, and more: composing functorch transforms", I would like to implement Hessian-vector products using the torch.func module.

Python: 3.10.2
PyTorch: 2.0.1

Here is a minimal working example of what I would like to do:

import torch

dimensions = (3, 1)  # (in_features, out_features)
num_data = 10

x = torch.rand(num_data, dimensions[0])
y = torch.rand(num_data, dimensions[1])

model = torch.nn.Linear(*dimensions)
# random direction vector with the same structure as the parameters
vec = {name: torch.rand(*p.shape) for name, p in model.named_parameters() if p.requires_grad}
params = dict(model.named_parameters())


def model_func(new_params):
    # MSE loss of the model evaluated with the given parameter dict
    outputs = torch.func.functional_call(model, new_params, (x,), strict=True)
    return torch.nn.functional.mse_loss(outputs, y)


def hvp_revrev(f, primals, tangents):
    # reverse-over-reverse: Hv is the vjp of grad(f) with cotangents v
    _, vjp_fn = torch.func.vjp(torch.func.grad(f), *primals)
    return vjp_fn(*tangents)


# this works (reverse-over-reverse)
hess_rev_vp = hvp_revrev(model_func, (params,), (vec,))[0]

# raises the error below (forward-over-reverse)
hess_vp = torch.func.jvp(torch.func.grad(model_func), (params,), (vec,))
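
As a sanity check, the working reverse-over-reverse result can be compared against torch.autograd.functional.hvp. The loss_wb helper below, which takes weight and bias as explicit tensor arguments, is just my own reformulation of model_func for that API:

from torch.autograd.functional import hvp as autograd_hvp

def loss_wb(weight, bias):
    # same loss as model_func, but with weight and bias as explicit tensors
    outputs = torch.nn.functional.linear(x, weight, bias)
    return torch.nn.functional.mse_loss(outputs, y)

# returns (loss_value, (Hv_weight, Hv_bias)); the second entry should match
# hess_rev_vp["weight"] and hess_rev_vp["bias"] from above
_, hvp_ref = autograd_hvp(loss_wb, (params["weight"], params["bias"]), (vec["weight"], vec["bias"]))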

While the reverse-over-reverse mode works fine, the forward-over-reverse mode raises the following error:

Traceback (most recent call last):
  File "test_hvp.py", line 26, in <module>
    hess_vp = torch.func.jvp(torch.func.grad(model_func), (params, ), (vec,))
  File "/opt/homebrew/Caskroom/mambaforge/base/envs/pydvl/lib/python3.10/site-packages/torch/_functorch/eager_transforms.py", line 916, in jvp
    return _jvp_with_argnums(func, primals, tangents, argnums=None, strict=strict, has_aux=has_aux)
  File "/opt/homebrew/Caskroom/mambaforge/base/envs/pydvl/lib/python3.10/site-packages/torch/_functorch/vmap.py", line 39, in fn
    return f(*args, **kwargs)
  File "/opt/homebrew/Caskroom/mambaforge/base/envs/pydvl/lib/python3.10/site-packages/torch/_functorch/eager_transforms.py", line 965, in _jvp_with_argnums
    result_duals = func(*duals)
  File "/opt/homebrew/Caskroom/mambaforge/base/envs/pydvl/lib/python3.10/site-packages/torch/_functorch/eager_transforms.py", line 1380, in wrapper
    results = grad_and_value(func, argnums, has_aux=has_aux)(*args, **kwargs)
  File "/opt/homebrew/Caskroom/mambaforge/base/envs/pydvl/lib/python3.10/site-packages/torch/_functorch/vmap.py", line 39, in fn
    return f(*args, **kwargs)
  File "/opt/homebrew/Caskroom/mambaforge/base/envs/pydvl/lib/python3.10/site-packages/torch/_functorch/eager_transforms.py", line 1267, in wrapper
    flat_grad_input = _autograd_grad(flat_outputs, flat_diff_args, create_graph=True)
  File "/opt/homebrew/Caskroom/mambaforge/base/envs/pydvl/lib/python3.10/site-packages/torch/_functorch/eager_transforms.py", line 113, in _autograd_grad
    grad_inputs = torch.autograd.grad(diff_outputs, inputs, grad_outputs,
  File "/opt/homebrew/Caskroom/mambaforge/base/envs/pydvl/lib/python3.10/site-packages/torch/autograd/__init__.py", line 303, in grad
    return Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
RuntimeError: ZeroTensors are immutable. Please use the materialized zero tensor obtained using .clone() if you want a mutable tensor.
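
For comparison, the same jvp-over-grad composition with a single flat parameter tensor, as in the functorch tutorial, would look like the sketch below (the flat_loss helper and the hard-coded reshapes for this 3-to-1 linear model are mine; I have not verified whether this avoids the error):

flat_params = torch.cat([p.detach().reshape(-1) for p in params.values()])
flat_vec = torch.cat([v.reshape(-1) for v in vec.values()])

def flat_loss(theta):
    # unpack weight (1, 3) and bias (1,) from the flat parameter vector
    weight, bias = theta[:3].reshape(1, 3), theta[3:]
    outputs = torch.nn.functional.linear(x, weight, bias)
    return torch.nn.functional.mse_loss(outputs, y)

_, flat_hvp = torch.func.jvp(torch.func.grad(flat_loss), (flat_params,), (flat_vec,))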

Is this expected, and if so, could you please explain this error message to me? Thank you :)