Is it possible to obtain the gradient function for builtin functions, such as `tanh`, similar to the `backward` method of `Function` objects? E.g. for `tanh`, I would like to obtain a function that evaluates `lambda x: 1 - x.tanh().square()`.
AFAIK, there is no way to obtain a symbolic form or Python code. Depending on your needs, you can generate code with SymPy or full-fledged CAS software, but that’s a somewhat manual approach.
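A minimal sketch of the SymPy route, assuming SymPy and NumPy are installed: differentiate symbolically, compile the result with `sympy.lambdify`, and cross-check it against PyTorch autograd. This only covers functions you write out by hand in SymPy; it does not read anything from PyTorch itself.

```python
import numpy as np
import sympy as sp
import torch

x = sp.Symbol("x")
dtanh = sp.diff(sp.tanh(x), x)                     # symbolic derivative: 1 - tanh(x)**2
dtanh_np = sp.lambdify(x, dtanh, modules="numpy")  # compile it to a NumPy function

# Cross-check against autograd on a few sample points.
xs = np.linspace(-2.0, 2.0, 5)
t = torch.tensor(xs, requires_grad=True)
(autograd_vals,) = torch.autograd.grad(t.tanh().sum(), t)

print(dtanh)
print(np.allclose(dtanh_np(xs), autograd_vals.numpy()))  # True
```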
Derivatives for elementary functions are available in the source, e.g. the `tanh` derivative is implemented here. I’m just not sure how to get at it from Python.
In other words, I don’t need a symbolic form or a Python representation, just a black-box function that evaluates the gradient.
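import torch
# Call ATen's backward kernel directly: tanh_backward(grad_output, output)
torch.ops.aten.tanh_backward(torch.ones(3), torch.tanh(torch.full((3,), 2.0)))
# Generic alternative: let autograd differentiate any function f
f = lambda x: x.tanh()
def df(x, grad_out=None):
    x = x.detach().clone().requires_grad_()
    grad_out = grad_out if grad_out is not None else torch.ones_like(x)  # `or` fails on tensors
    return torch.autograd.grad(f(x), x, grad_out)[0]  # grad() returns a tuple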
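A short usage check of both routes from the answer, comparing them with the hand-written formula from the question. Note that `tanh_backward` takes the forward *output* `tanh(x)`, not `x`; the variable names here are illustrative only.

```python
import torch

y = torch.tanh(torch.full((3,), 2.0))
# ATen's kernel computes grad_output * (1 - output**2) from the output alone.
via_aten = torch.ops.aten.tanh_backward(torch.ones(3), y)

# Generic autograd route, equivalent to the df() helper above.
x = torch.full((3,), 2.0, requires_grad=True)
(via_autograd,) = torch.autograd.grad(x.tanh(), x, torch.ones_like(x))

# Formula from the question: 1 - tanh(x)^2
by_hand = 1 - y.square()
print(via_aten, via_autograd, by_hand)
```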
Neat, `tanh_backward` is what I’m looking for. Unfortunately, bindings for `sin_backward` don’t exist (because it’s just `cos`). But of course PyTorch “knows” that `cos` is the derivative of the builtin function `sin`. What I’d like to do is extract that knowledge. Pseudocode below:
import torch as th

def get_grad(builtin_function):
    """
    Returns a function that evaluates the gradient of the
    `builtin_function`. Raises `TypeError` if the function
    does not have a registered gradient.

    >>> get_grad(th.sin)  # Returns a function.
    >>> get_grad(lambda x: x.exp().sin())  # Raises TypeError.
    """
    raise NotImplementedError  # desired behaviour only; no such PyTorch API exists
Well, PyTorch generates code from `tools/autograd/derivatives.yaml`:
- name: sin(Tensor self) -> Tensor
  self: grad * self.cos().conj()

- name: tanh(Tensor self) -> Tensor
  self: tanh_backward(grad, result)
Note that for y = F(x), some backward formulas need x (`self`) as input and some need y (`result`).
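Written out for the two entries above, with $\texttt{grad} = \partial L / \partial y$ and real inputs (so `.conj()` is a no-op), this shows why `sin` needs `self` while `tanh` can get by with `result`:

```latex
\begin{aligned}
y = \sin x &: \quad \frac{\partial L}{\partial x} = \texttt{grad} \cdot \cos x
  && \text{(needs } x = \texttt{self}\text{)} \\
y = \tanh x &: \quad \frac{\partial L}{\partial x} = \texttt{grad} \cdot \left(1 - \tanh^2 x\right)
  = \texttt{grad} \cdot \left(1 - y^2\right)
  && \text{(needs } y = \texttt{result}\text{)}
\end{aligned}
```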
The generated functions are in `torch/csrc/autograd/generated/Functions.cpp`, but I’m not sure whether their interface is exported to Python.
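One partial route, as far as I know: instances of these generated classes show up in Python as the `grad_fn` of an output, and such a node can be called directly on an incoming gradient. A minimal sketch, assuming a recent PyTorch; the class name (`SinBackward0` here) is version dependent, and this still requires running the forward op first so the node has its saved inputs.

```python
import torch

x = torch.linspace(-2.0, 2.0, 5, requires_grad=True)
y = torch.sin(x)

# The autograd node recorded for sin is an instance of a generated class.
print(type(y.grad_fn).__name__)  # e.g. SinBackward0 (name may vary by version)

# Calling the node applies the "self: grad * self.cos().conj()" formula to an
# incoming gradient, using the x saved during the forward pass.
# no_grad keeps the result out of the autograd graph.
with torch.no_grad():
    dx = y.grad_fn(torch.ones_like(y))
print(dx)  # same values as torch.cos(x)
```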
If you want to differentiate a function directly within PyTorch, have a look at the functorch library and use `grad` or `jacrev`/`jacfwd` on a function to get a functional form of its derivative.
Here’s an example:
import torch
from functorch import grad, vmap

x = torch.randn(10)  # samples

def f(x):
    return torch.sin(x)

grad_f = grad(f)  # differentiate f(x) w.r.t. x
df_dx = vmap(grad_f, in_dims=0)(x)  # vmap over samples
exact = torch.cos(x)  # analytical result
print("Check: ", torch.allclose(exact, df_dx))  # compare results (returns True)