# How can I calculate a gradient inside a function passed to `jacrev()`?

Hi everyone. Here is a function that computes the Neural Tangent Kernel (NTK):

```python
import torch
from torch.func import functional_call, jacrev, vmap


def ntk(model, x1, x2, compute="full"):
    params = dict(model.named_parameters())

    def fnet_single_torch(params, x):
        # Evaluate the model on one sample (add, then remove, the batch dim)
        y = functional_call(model, params, x.unsqueeze(0)).squeeze(0)
        return y

    # Per-sample Jacobians w.r.t. each parameter tensor, flattened to [N, O, P]
    jac1 = vmap(jacrev(fnet_single_torch), (None, 0))(params, x1)
    jac1 = [j.flatten(2) for j in jac1.values()]

    jac2 = vmap(jacrev(fnet_single_torch), (None, 0))(params, x2)
    jac2 = [j.flatten(2) for j in jac2.values()]

    # Compute J(x1) @ J(x2).T
    if compute == "full":
        einsum_expr = "Naf,Mbf->NMab"
    elif compute == "trace":
        einsum_expr = "Naf,Maf->NM"
    elif compute == "diagonal":
        einsum_expr = "Naf,Maf->NMa"
    else:
        raise ValueError(f"unknown compute mode: {compute!r}")

    # Contract each parameter tensor separately, then sum the contributions
    result = torch.stack(
        [torch.einsum(einsum_expr, j1, j2) for j1, j2 in zip(jac1, jac2)]
    )
    result = result.sum(0)
    return result.squeeze()
```
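For reference, this is what the three `compute` modes evaluate. Here $p$ runs over the parameter tensors (the `zip(jac1, jac2)` loop followed by `.sum(0)`), $f$ over the flattened entries of one tensor (`flatten(2)`), $a, b$ over output dimensions, and $N, M$ over the samples behind `jac1` and `jac2`:

```latex
\begin{aligned}
J_p(x) &= \frac{\partial f(\theta, x)}{\partial \theta_p}, \qquad
\Theta(x_1, x_2) = \sum_{p} J_p(x_1)\, J_p(x_2)^{\top} \\
\text{full: } \Theta_{NMab} &= \sum_{p} \sum_{f} J_p(x_1)_{Naf}\, J_p(x_2)_{Mbf} \\
\text{trace: } \Theta_{NM} &= \sum_{p} \sum_{a} \sum_{f} J_p(x_1)_{Naf}\, J_p(x_2)_{Maf} \\
\text{diagonal: } \Theta_{NMa} &= \sum_{p} \sum_{f} J_p(x_1)_{Naf}\, J_p(x_2)_{Maf}
\end{aligned}
```

So "trace" is the trace of the full kernel over the output dimensions, and "diagonal" keeps only the $a = b$ entries.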

Now I need to apply a transformation to the network output, and that transformation itself contains a gradient term (a derivative with respect to the input `x`). How can I modify this code?

I tried adding the transformation directly inside `fnet_single_torch`, as below, but I noticed that inside `jacrev()` the incoming `x` does not require grad, so I cannot take that gradient:

```python
def output_transform(x, y):
    ...  # transformation containing a gradient term (body not shown)


def fnet_single_torch(params, x):
    y = functional_call(model, params, x.unsqueeze(0)).squeeze(0)
    y = output_transform(x, y)
    return y
```
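One possible approach, sketched here as an assumption rather than a confirmed fix: take the gradient term with `torch.func` itself instead of `torch.autograd.grad`. Function transforms compose, so an inner `jacrev(..., argnums=1)` with respect to `x` can sit inside the outer `jacrev` over `params`, and `x` never needs `requires_grad`. The toy model (1-D input and output), the helper `net`, and the transform `y + x * dy/dx` are all hypothetical stand-ins; note that `output_transform` here also takes `params` so the inner derivative is traced by the outer `jacrev`.

```python
import torch
from torch import nn
from torch.func import functional_call, jacrev, vmap

torch.manual_seed(0)
# Hypothetical toy model: 1-D input, 1-D output
model = nn.Sequential(nn.Linear(1, 16), nn.Tanh(), nn.Linear(16, 1))
params = {k: v.detach() for k, v in model.named_parameters()}


def net(params, x):
    # x is a single sample of shape (1,); the output has shape (1,)
    return functional_call(model, params, (x.unsqueeze(0),)).squeeze(0)


def output_transform(params, x, y):
    # Hypothetical transform y + x * dy/dx. dy/dx comes from torch.func.jacrev
    # (argnums=1 -> w.r.t. x), so it composes with the outer jacrev/vmap.
    dydx = jacrev(net, argnums=1)(params, x)  # shape (1, 1)
    return y + x * dydx.squeeze(-1)


def fnet_single_torch(params, x):
    y = net(params, x)
    return output_transform(params, x, y)


x = torch.linspace(-1.0, 1.0, 5).unsqueeze(-1)  # 5 samples, shape (5, 1)
# Jacobian of the *transformed* output w.r.t. every parameter tensor
jac = vmap(jacrev(fnet_single_torch), (None, 0))(params, x)
for name, j in jac.items():
    print(name, tuple(j.shape))
```

The resulting dict has the same layout as `jac1`/`jac2` in the NTK function, so it can be flattened and contracted in exactly the same way.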

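If you want the derivatives of your model output with respect to its inputs within the `torch.func` namespace, you can use `torch.func.grad`. Two details matter: `argnums` selects which argument is differentiated (`argnums=1` is `x`; `argnums=0` would differentiate with respect to `params`), and `grad` requires the function to return a scalar, so for a vector-valued output use `torch.func.jacrev` with the same `argnums` instead:

```python
import torch
from torch.func import functional_call


def f(params, x):
    # One sample in, 0-dim scalar out (assumes a single-output model)
    return functional_call(model, params, x.unsqueeze(0)).squeeze()


# Per-sample gradient of the output w.r.t. the input x
grad_f = torch.func.vmap(torch.func.grad(f, argnums=1), in_dims=(None, 0))(params, x)
```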
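A self-contained sketch of input gradients with `torch.func`, assuming a hypothetical toy model with a 3-D input and a single output. It computes per-sample gradients with `grad(..., argnums=1)` and, for comparison, the same quantity with `jacrev` for a function that keeps a vector-valued output:

```python
import torch
from torch import nn
from torch.func import functional_call, grad, jacrev, vmap

torch.manual_seed(0)
# Hypothetical toy model: 3-D input, 1 output
model = nn.Sequential(nn.Linear(3, 8), nn.Tanh(), nn.Linear(8, 1))
params = {k: v.detach() for k, v in model.named_parameters()}


def f(params, x):
    # Single sample of shape (3,) -> 0-dim scalar, as torch.func.grad requires
    return functional_call(model, params, (x.unsqueeze(0),)).squeeze()


def f_vec(params, x):
    # Same model, but keeps a vector output of shape (1,) -> needs jacrev
    return functional_call(model, params, (x.unsqueeze(0),)).squeeze(0)


x = torch.randn(4, 3)  # batch of 4 samples

# argnums=1 differentiates w.r.t. x, not params
dfdx = vmap(grad(f, argnums=1), in_dims=(None, 0))(params, x)
dfdx_vec = vmap(jacrev(f_vec, argnums=1), in_dims=(None, 0))(params, x)

print(dfdx.shape)      # torch.Size([4, 3])
print(dfdx_vec.shape)  # torch.Size([4, 1, 3])
```

`jacrev` adds an output dimension, so for a one-output model `dfdx_vec.squeeze(1)` matches `dfdx`.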