Can jacfwd track a variable through a call to another function?

For example:

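import torch
from torch.func import jacfwd  # functorch.jacfwd in older PyTorch releases

def nestfunction(a1):
    b1 = a1**2
    return b1

def foo(unknown):
    a1 = unknown[::2]       # even-indexed entries
    a2 = unknown[1::2]      # odd-indexed entries
    b1 = nestfunction(a1)   # a1 is passed to another Python function
    r = b1 + a2
    return r

# Example input (hypothetical); unknown needs an even length so a1 and a2 have the same size
unknown = torch.tensor([1.0, 2.0, 3.0, 4.0], dtype=torch.float64)
r = foo(unknown)
J = jacfwd(foo)(unknown)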
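As a reference point for what jacfwd should return, the Jacobian of foo can be worked out by hand. This assumes unknown is a 1-D tensor x of even length 2n (the post doesn't state the shape; with an odd length, a1 and a2 would differ in size and the addition would fail):

```latex
r_i = x_{2i}^2 + x_{2i+1}, \qquad i = 0, \dots, n-1

\frac{\partial r_i}{\partial x_j} =
\begin{cases}
2x_{2i} & j = 2i \\
1 & j = 2i + 1 \\
0 & \text{otherwise}
\end{cases}
```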

Can jacfwd still track a1 even though a1 is passed to another function?
Thank you, everyone.

Autograd operates only at the ATen operator level: it doesn't know what happens at the Python level, so everything should work as if the function were inlined.
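A quick way to check this is to compute the Jacobian once with the helper function and once with the same math written inline, and compare. This is a minimal sketch assuming PyTorch 2.x (`torch.func.jacfwd`); the names `foo_nested`, `foo_inlined` and the input `x` are illustrative, not from the original post.

```python
import torch
from torch.func import jacfwd

def nestfunction(a1):
    return a1**2

def foo_nested(unknown):
    # a1 = unknown[::2] goes through a separate Python function
    return nestfunction(unknown[::2]) + unknown[1::2]

def foo_inlined(unknown):
    # the same computation, written inline
    return unknown[::2]**2 + unknown[1::2]

# hypothetical example input of even length
x = torch.tensor([1.0, 2.0, 3.0, 4.0], dtype=torch.float64)

J_nested = jacfwd(foo_nested)(x)
J_inlined = jacfwd(foo_inlined)(x)

print(torch.equal(J_nested, J_inlined))  # True
print(J_nested)
```

Both versions record the same sequence of ATen ops (slice, pow, add), so the Python function boundary makes no difference to the result. Tracking would only be lost if a1 left the tensor world, for example via `.item()` or a conversion to NumPy.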