PyTorch is a bit particular in the sense that the definition of your function is discovered at the same time as it is evaluated at a point. So you cannot get the gradient without evaluating the function at a given point. (This allows much more flexibility in what is allowed with respect to control flow and in-place operations.)
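For example, here is a minimal sketch of what that looks like in eager mode (the function f and the point x0 are just placeholders): the graph is only recorded while f is evaluated, and the gradient is then taken at that concrete point.

import torch

def f(x):
    # placeholder function, stands in for whatever you want to differentiate
    return (x ** 2).sum()

x0 = torch.randn(3, requires_grad=True)   # the point of evaluation
y = f(x0)                                 # the graph is built during this forward evaluation
grad_x0, = torch.autograd.grad(y, x0)     # gradient of f at x0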
If you want static graphs that are differentiated symbolically (without evaluation), you can turn to TorchScript, which is designed for that.
Thanks, that sounds like exactly what I’m looking for! However, after looking online for a while, I’m still confused about how exactly to solve my problem using TorchScript.
The TorchScript API for differentiation is the same as the eager-mode API, so if something is not expressible in Python, it won’t be expressible in TorchScript either. It is still possible to write something that looks like a grad function, but internally it will always compute the forward pass first:
def grad(f):
    def result(x):
        # make leaf variables out of the inputs
        x_ = x.detach().requires_grad_(True)
        f(x_).backward()
        return x_.grad
    return result
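For example, a hypothetical usage of this wrapper (assuming f returns a scalar, which backward() requires when called without arguments):

import torch

def f(x):
    return (x ** 3).sum()   # must return a scalar for .backward()

df = grad(f)
x0 = torch.ones(3)
print(df(x0))               # 3 * x0 ** 2, i.e. tensor([3., 3., 3.])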
We do not do the kind of whole-program transformations that Mathematica or JAX do, which would make it possible to generate the backward pass without evaluation.
I don’t think there is any specific API for this. Just compute the gradient the same way you would otherwise.
Note that while JAX provides a closure for grad, in some cases it may still re-compute the forward pass every time you use it.
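For comparison, a minimal JAX sketch of the closure being described (f is again just a placeholder):

import jax
import jax.numpy as jnp

def f(x):
    return jnp.sum(x ** 2)

df = jax.grad(f)            # df is a closure; nothing is computed yet
x0 = jnp.ones(3)
print(df(x0))               # the forward pass is traced when df is actually called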