make_fx and vjp

The functorch README has a nice example of source-to-source AD with grad (it also works nicely with jacrev):

import torch
from functorch import make_fx, jacrev
def f(x):
    return torch.sin(x).sum()
x = torch.randn(100)
grad_f = make_fx(jacrev(f))(x)
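For reference, here is a minimal sketch of the grad variant from the README. It assumes a functorch version where make_fx can be imported from functorch, as in the snippet above. make_fx returns a torch.fx GraphModule, so the traced gradient can be printed as Python source and called again on new inputs:

```python
import torch
from functorch import make_fx, grad

def f(x):
    return torch.sin(x).sum()

x = torch.randn(100)
# Trace the gradient computation into an FX graph (a torch.fx.GraphModule)
grad_f = make_fx(grad(f))(x)

# Generated Python source for the traced gradient computation
print(grad_f.code)

# The graph is reusable on new inputs: d/dx sum(sin(x)) = cos(x)
x_new = torch.randn(100)
print(torch.allclose(grad_f(x_new), torch.cos(x_new)))
```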

I’m trying to make the same thing work for vjp, but I can’t quite find the right incantation. I would have thought this should work:

import torch
from functorch import make_fx, vjp
def f(x):
    return torch.sin(x)
x = torch.randn(10)
dret = torch.randn(10)
grad_f = make_fx(vjp(f,x)[1])(dret)

But it throws:

RuntimeError: Tracing expected 3 arguments but got 1 concrete arguments

I can sort of see why it might not work, but I wonder whether there’s a workaround.
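The error most likely comes from the signature of the function that vjp returns: besides the cotangents, it also accepts optional retain_graph and create_graph arguments, and make_fx appears to count every parameter in the signature when matching it against the concrete inputs passed in (hence "expected 3 ... got 1"). One possible workaround, sketched here under the assumption that vjp composes with make_fx the same way grad does, is to call vjp inside the function being traced so that make_fx only sees plain tensor arguments. The helper name vjp_of_f is hypothetical:

```python
import torch
from functorch import make_fx, vjp

def f(x):
    return torch.sin(x)

# Hypothetical helper: run the whole VJP inside one function whose only
# parameters are tensors, so make_fx sees a two-argument signature.
def vjp_of_f(x, dret):
    _, vjp_fn = vjp(f, x)
    # vjp_fn returns a tuple with one cotangent per primal input
    (dx,) = vjp_fn(dret)
    return dx

x = torch.randn(10)
dret = torch.randn(10)
traced = make_fx(vjp_of_f)(x, dret)
print(traced.code)

# The VJP of sin is cos(x) * dret
print(torch.allclose(traced(x, dret), torch.cos(x) * dret))
```

Because x is an input of the traced graph here, the forward computation is recorded as well, and the same graph can be reused for other (x, dret) pairs of the same shape.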