I tried to reproduce the minimal example from the torch.func.jacrev documentation, but with a small change to the function g (it now takes a keyword argument):
import torch
x = torch.randn(5)
def f(x):
    return x.sin()
def g(x, constant=1):
    result = f(x) * constant
    return result, result
jacobian_f, f_x = torch.func.jacrev(g, has_aux=True)(x, constant=2)
Running this gives me the following error:
TypeError: g() got an unexpected keyword argument 'constant'
What am I doing wrong? If I don’t specify constant=2 and just pass in 2 positionally, then it works, but I want to specify keyword args (for readability and for when I vmap around this).
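One possible workaround, sketched under the assumption that the installed version of jacrev does not forward keyword arguments to the wrapped function: bind the keyword first with a small wrapper, so that jacrev only ever sees a function of x. The name g_const2 is made up for illustration.

```python
import torch

x = torch.randn(5)

def f(x):
    return x.sin()

def g(x, constant=1):
    result = f(x) * constant
    return result, result

# Hypothetical helper: fix the keyword argument up front, so the
# function handed to jacrev takes only the positional input x.
g_const2 = lambda x: g(x, constant=2)

# jacrev differentiates with respect to x; the second output of g
# is returned unchanged as the auxiliary value.
jacobian_f, f_x = torch.func.jacrev(g_const2, has_aux=True)(x)
```

The call site stays readable because constant=2 is still written out by name, and the same wrapper can be passed on to vmap, since it is an ordinary single-argument function.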