I'm afraid we don't have a good solution for this at the moment.
Once vmap is released, you will be able to use it to vmap over the v argument.
If you're happy with a hacky workaround, you can take the jvp function from torch/autograd/functional.py and modify it as follows:
import torch
from torch.autograd.functional import _as_tuple, _grad_preprocess, _check_requires_grad, _validate_v, _autograd_grad, _fill_in_zeros, _grad_postprocess, _tuple_postprocess
def fw_linearize(func, inputs, create_graph=False, strict=False):
    """Linearize ``func`` at ``inputs`` and return a reusable JVP function.

    Runs ``func`` once and builds the double-backward ("back trick") graph a
    single time, so the returned ``lin_fn`` can be called repeatedly with
    different tangent vectors ``v`` without re-evaluating ``func``.

    Args:
        func: Python function taking Tensor inputs and returning Tensor(s).
        inputs: Tensor or tuple of Tensors at which ``func`` is linearized.
        create_graph: if True, the JVP itself is computed differentiably.
        strict: if True, raise instead of silently filling in zeros when an
            output is detected to be independent of the inputs.

    Returns:
        ``lin_fn(v, retain_graph=True)`` — computes the jacobian-vector
        product of ``func`` at ``inputs`` in direction ``v``.  Keep the
        default ``retain_graph=True`` to allow multiple calls; pass False on
        the last call to free the graph.
    """
    is_inputs_tuple, inputs = _as_tuple(inputs, "inputs", "jvp")
    inputs = _grad_preprocess(inputs, create_graph=create_graph, need_graph=True)

    outputs = func(*inputs)
    is_outputs_tuple, outputs = _as_tuple(outputs, "outputs of the user-provided function", "jvp")
    _check_requires_grad(outputs, "outputs", strict=strict)

    # The backward is linear so the value of grad_outputs is not important as
    # it won't appear in the double backward graph. We only need to ensure that
    # it does not contain inf or nan.
    grad_outputs = tuple(torch.zeros_like(out, requires_grad=True) for out in outputs)

    # First backward pass, kept differentiable so we can differentiate it
    # again w.r.t. grad_outputs inside lin_fn.
    grad_inputs = _autograd_grad(outputs, inputs, grad_outputs, create_graph=True)
    _check_requires_grad(grad_inputs, "grad_inputs", strict=strict)

    def lin_fn(v, retain_graph=True):
        """Compute the JVP of ``func`` at ``inputs`` in direction ``v``."""
        if v is not None:
            _, v = _as_tuple(v, "v", "jvp")
            v = _grad_preprocess(v, create_graph=create_graph, need_graph=False)
            _validate_v(v, inputs, is_inputs_tuple)
        else:
            if len(inputs) != 1 or inputs[0].nelement() != 1:
                raise RuntimeError("The vector v can only be None if the input to "
                                   "the user-provided function is a single Tensor "
                                   "with a single element.")

        # Differentiating the first backward w.r.t. grad_outputs in direction
        # v yields the forward-mode JVP (the "double backward trick").
        grad_res = _autograd_grad(grad_inputs, grad_outputs, v,
                                  create_graph=create_graph, retain_graph=retain_graph)

        jvp = _fill_in_zeros(grad_res, outputs, strict, create_graph, "back_trick")

        # Cleanup objects and return them to the user
        jvp = _grad_postprocess(jvp, create_graph)
        return _tuple_postprocess(jvp, is_outputs_tuple)

    return lin_fn
def my_fun(x):
    """Cube every element of ``x`` and sum the result to a scalar."""
    cubed = x ** 3
    return cubed.sum()
inp = torch.ones(4)
lin = fw_linearize(my_fun, inp)

# Probe the linearized map with a few tangent directions: the zero vector,
# the all-ones vector, and each standard basis vector in turn.
directions = [torch.zeros(4), torch.ones(4)]
directions.extend(torch.eye(4))
for v in directions:
    print(lin(v))