Pass keyword arguments into jacrev's func

I tried to reproduce the minimal example from torch.func.jacrev’s documentation, with one small change: the function g now takes a keyword argument:

import torch

x = torch.randn(5)

def f(x):
    return x.sin()

def g(x, constant=1):
    result = f(x) * constant
    return result, result

jacobian_f, f_x = torch.func.jacrev(g, has_aux=True)(x, constant=2)

Running this gives me the following error:

TypeError: g() got an unexpected keyword argument 'constant'

What am I doing wrong? If I don’t write constant=2 and just pass 2 positionally, it works, but I want to use keyword arguments (for readability, and for when I vmap around this).

Hi @skumar_ml,

I don’t believe the function returned by jacrev accepts keyword arguments. It expects the inputs to be passed positionally, so constant=2 is rejected and you get that TypeError.
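Since the returned function only takes positional inputs, one way to keep the keyword readable is to bind it before handing g to jacrev. This is a minimal sketch, assuming standard-library functools.partial (or an equivalent lambda) as the binding mechanism; the names g_const2, jac_partial and so on are just illustrative:

```python
import functools

import torch
from torch.func import jacrev

def f(x):
    return x.sin()

def g(x, constant=1):
    result = f(x) * constant
    return result, result

x = torch.randn(5)

# Bind the keyword argument up front; jacrev then only receives x positionally.
g_const2 = functools.partial(g, constant=2)
jac_partial, fx_partial = jacrev(g_const2, has_aux=True)(x)

# A lambda works the same way if you prefer to keep the keyword at the call site.
jac_lambda, fx_lambda = jacrev(lambda x: g(x, constant=2), has_aux=True)(x)

# Both should match passing the constant positionally, which already works.
jac_pos, fx_pos = jacrev(g, has_aux=True)(x, 2)

print(torch.allclose(jac_partial, jac_pos), torch.allclose(jac_lambda, jac_pos))
```

Because the keyword is bound before jacrev sees the function, jacrev only has to differentiate with respect to x, which is what its default argnums=0 does.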

Also, you should wrap your torch.func.jacrev call in a torch.func.vmap call. Because g acts element-wise, its Jacobian is diagonal; without vmap you compute the full 5x5 matrix, whose off-diagonal entries are all zero.
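To check the diagonal claim, the sketch below (assuming the same f and g, and comparing against the analytic derivative of 2*sin(x), which is 2*cos(x)) computes the full Jacobian and the vmapped per-element derivatives side by side:

```python
import torch
from torch.func import jacrev, vmap

def f(x):
    return x.sin()

def g(x, constant=1):
    result = f(x) * constant
    return result, result

x = torch.randn(5)

# Full Jacobian: a 5x5 matrix, zero off the diagonal because g is element-wise.
full_jac, _ = jacrev(g, has_aux=True)(x, 2)

# vmap hands jacrev one scalar element of x at a time (the constant 2 is not
# batched, hence None), so each call returns a 0-d derivative; stacked, that
# is just the diagonal, with shape (5,).
diag_jac, _ = vmap(jacrev(g, has_aux=True), in_dims=(0, None))(x, 2)

print(full_jac.shape, diag_jac.shape)  # torch.Size([5, 5]) torch.Size([5])
# The diagonal of the full Jacobian matches the vmapped result and 2*cos(x).
print(torch.allclose(torch.diagonal(full_jac), diag_jac))
print(torch.allclose(diag_jac, 2 * torch.cos(x)))
```

The vmapped version avoids materializing the n x n matrix, which matters once x is large.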

Here’s an example:

import torch
from torch.func import jacrev, vmap

x = torch.randn(5)

def f(x):
    return x.sin()

def g(x, constant=1):
    result = f(x) * constant
    return result, result

# Map over x (in_dims=0) while passing the constant 2 unbatched (in_dims=None)
jacobian_f, f_x = vmap(jacrev(g, argnums=(0,), has_aux=True), in_dims=(0, None))(x, 2)
print(jacobian_f, f_x)
# returns (exact values depend on the random x):
# (tensor([ 0.0822,  1.5112, -0.8985,  1.6367, -0.7803]),) tensor([ 1.9983, -1.3101, -1.7868,  1.1494,  1.8415])

# Without vmap, i.e. jacrev(g, argnums=(0,), has_aux=True)(x, 2), returns (for a different random x):
# (tensor([[ 1.7465,  0.0000,  0.0000,  0.0000, -0.0000],
#         [ 0.0000,  1.9095,  0.0000,  0.0000, -0.0000],
#         [ 0.0000,  0.0000,  1.6745,  0.0000, -0.0000],
#         [ 0.0000,  0.0000,  0.0000,  1.9700, -0.0000],
#         [ 0.0000,  0.0000,  0.0000,  0.0000, -1.6554]]),) tensor([ 0.9746, -0.5948,  1.0937,  0.3451,  1.1223])
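Since the original goal was to keep constant=2 readable when vmapping, binding the keyword and vmapping can be combined. This is a sketch assuming functools.partial for the binding; once the constant is bound, x is the only input, so the default in_dims=0 is enough and no None entry is needed:

```python
import functools

import torch
from torch.func import jacrev, vmap

def f(x):
    return x.sin()

def g(x, constant=1):
    result = f(x) * constant
    return result, result

x = torch.randn(5)

# The keyword is bound before any transform, so neither jacrev nor vmap sees it.
per_elem_jac = vmap(jacrev(functools.partial(g, constant=2), has_aux=True))
jacobian_f, f_x = per_elem_jac(x)

# Should agree with the positional form from the example above
# (argnums=(0,) there makes the Jacobian a 1-tuple, hence ref_jac[0]).
ref_jac, ref_fx = vmap(jacrev(g, argnums=(0,), has_aux=True), in_dims=(0, None))(x, 2)
print(torch.allclose(jacobian_f, ref_jac[0]), torch.allclose(f_x, ref_fx))
```

Using the default integer argnums also means the Jacobian comes back as a plain tensor rather than a 1-tuple.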