Hi everyone. Here is a function that computes the empirical Neural Tangent Kernel (NTK) of a model on a batch of inputs:
import torch
from torch.func import functional_call, vmap, jacrev

def ntk(model, x, compute="full"):
    params = dict(model.named_parameters())

    def fnet_single_torch(params, x):
        # Functional forward pass on a single sample.
        y = functional_call(model, params, x.unsqueeze(0)).squeeze(0)
        return y

    # Per-sample Jacobians w.r.t. the parameters: one tensor per parameter,
    # each of shape (N, out_dim, *param.shape), flattened to (N, out_dim, -1).
    jac1 = vmap(jacrev(fnet_single_torch), (None, 0))(params, x)
    jac1 = [j.flatten(2) for j in jac1.values()]
    jac2 = vmap(jacrev(fnet_single_torch), (None, 0))(params, x)
    jac2 = [j.flatten(2) for j in jac2.values()]

    # Contract J(x1) @ J(x2).T, one parameter block at a time.
    if compute == "full":
        einsum_expr = "Naf,Mbf->NMab"
    elif compute == "trace":
        einsum_expr = "Naf,Maf->NM"
    elif compute == "diagonal":
        einsum_expr = "Naf,Maf->NMa"
    else:
        raise ValueError(f"unknown compute mode: {compute!r}")

    result = torch.stack(
        [torch.einsum(einsum_expr, j1, j2) for j1, j2 in zip(jac1, jac2)]
    ).sum(0)
    return result.squeeze()
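For reference, a minimal call looks like this (the two-layer model and the shapes here are just placeholders):

model = torch.nn.Sequential(
    torch.nn.Linear(3, 16), torch.nn.Tanh(), torch.nn.Linear(16, 1)
)
x = torch.randn(5, 3)                # batch of 5 inputs, 3 features each
K = ntk(model, x, compute="full")    # squeeze() gives (5, 5) for a scalar output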
Now I need to apply a transformation to the network output, and this transformation contains a gradient term (the derivative of the output with respect to the input). How can I modify this code?
I have tried adding the transformation directly to the fnet_single_torch function like below, but I noticed that when it runs under jacrev(), the incoming x does not require grad, so autograd.grad() fails:
from torch import autograd

def output_transform(x, y):
    # Gradient of the output w.r.t. the input (VJP with a vector of ones).
    return autograd.grad(y, x, grad_outputs=torch.ones_like(y))[0]

def fnet_single_torch(params, x):
    y = functional_call(model, params, x.unsqueeze(0)).squeeze(0)
    # Fails here: under jacrev()/vmap(), x is not part of an eager autograd graph.
    y = output_transform(x, y)
    return y
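For what it's worth, one direction I am considering is replacing autograd.grad with torch.func.vjp, since the torch.func transforms are designed to compose with each other. This is an untested sketch, not something I have verified end to end:

from torch.func import vjp

def fnet_single_torch(params, x):
    def f(x_):
        return functional_call(model, params, x_.unsqueeze(0)).squeeze(0)
    # vjp returns the output y and a function computing v @ (dy/dx); with
    # v = ones this should match autograd.grad(y, x, grad_outputs=torch.ones_like(y)).
    y, vjp_fn = vjp(f, x)
    (dydx,) = vjp_fn(torch.ones_like(y))
    return dydx

Is this the right way to make the gradient term visible to the outer jacrev/vmap, or is there a better approach?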