Is it possible to look up the gradient of native functions?

Is it possible to obtain the gradient function for builtin functions such as tanh, similar to the backward method of Function objects? E.g. for tanh, I would like to obtain a function that evaluates lambda x: 1 - x.tanh().square().

AFAIK, there is no way to obtain a symbolic form or Python code. Depending on your needs, you can generate code with SymPy or full-fledged CAS software, but that’s a somewhat manual approach.
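A minimal sketch of the SymPy route, assuming SymPy is installed: differentiate symbolically, then compile the result into an ordinary numeric function with `lambdify`. This only covers functions you re-express in SymPy; it does not read anything out of PyTorch.

```python
import math

import sympy as sp

x = sp.symbols("x")

# Symbolic derivative of tanh, an expression equivalent to 1 - tanh(x)**2
dtanh_expr = sp.diff(sp.tanh(x), x)

# Compile the expression into a plain Python function backed by the math module
dtanh = sp.lambdify(x, dtanh_expr, modules="math")

print(dtanh_expr)
print(dtanh(0.5), 1 - math.tanh(0.5) ** 2)
```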

Derivatives for elementary functions are available in the source, e.g. the tanh derivative is implemented here. I’m just not sure how to get at it from Python.

In other words, I don’t need a symbolic form or a Python representation, just a black-box function that evaluates the gradient.

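import torch
# tanh_backward(grad_output, output) is fed tanh(x), not x
torch.ops.aten.tanh_backward(torch.ones(3), torch.tanh(torch.full((3,), 2.0)))

f = lambda x: x.tanh()
def df(x, grad_out=None):
    x = x.detach().clone().requires_grad_()
    grad_out = torch.ones_like(x) if grad_out is None else grad_out  # `or` fails on multi-element tensors
    return torch.autograd.grad(f(x), x, grad_out)[0]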
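A quick sanity check, as a sketch: `tanh_backward` expects the upstream gradient and the *output* tanh(x), not x itself, so it should agree both with the autograd route and with the closed form 1 - tanh(x)².

```python
import torch

x = torch.linspace(-2.0, 2.0, 5)
y = torch.tanh(x)
upstream = torch.ones_like(x)

# aten kernel: arguments are (grad_output, output), i.e. it is fed y = tanh(x)
g_kernel = torch.ops.aten.tanh_backward(upstream, y)

# Generic autograd route: differentiate tanh(x) with respect to x
xr = x.clone().requires_grad_()
(g_autograd,) = torch.autograd.grad(torch.tanh(xr), xr, upstream)

# Closed form from the question: 1 - tanh(x)^2
g_closed = 1 - y.square()

print(torch.allclose(g_kernel, g_autograd), torch.allclose(g_kernel, g_closed))
```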

Neat, tanh_backward is what I’m looking for. Unfortunately, there is no sin_backward binding (because the derivative is just cos). But of course PyTorch “knows” that cos is the derivative of the builtin function sin. What I’d like to do is extract that knowledge. Pseudocode below:

import torch as th

def get_grad(builtin_function):
    """Return a function that evaluates the gradient of `builtin_function`.

    Raises `TypeError` if the function does not have a registered gradient.

    >>> get_grad(th.sin)  # Returns a function.
    >>> get_grad(lambda x: x.exp().sin())  # Raises TypeError.
    """
    ...
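PyTorch does not seem to expose a public lookup like this, so one minimal way to satisfy the `get_grad` spec is to keep your own table. In this sketch `_DERIVATIVES` and the rule signature `(grad, x, result)` are hypothetical; the `sin` and `tanh` rules transcribe the `derivatives.yaml` entries quoted later in this thread. Anything not in the table, including composites like `x.exp().sin()`, raises `TypeError`.

```python
import torch

# Hypothetical hand-written table (NOT PyTorch's internal registry).
# Each rule maps (grad, x, result) -> gradient with respect to x.
_DERIVATIVES = {
    torch.sin: lambda grad, x, result: grad * x.cos().conj(),
    torch.tanh: lambda grad, x, result: torch.ops.aten.tanh_backward(grad, result),
}


def get_grad(builtin_function):
    """Return a function evaluating the gradient of `builtin_function`.

    Raises TypeError if the function has no entry in `_DERIVATIVES`.
    """
    try:
        rule = _DERIVATIVES[builtin_function]
    except (KeyError, TypeError):  # TypeError covers unhashable callables
        raise TypeError(f"no registered gradient for {builtin_function!r}") from None

    def grad_fn(x, grad=None):
        grad = torch.ones_like(x) if grad is None else grad
        return rule(grad, x, builtin_function(x))

    return grad_fn


x = torch.linspace(-1.0, 1.0, 5)
print(torch.allclose(get_grad(torch.sin)(x), x.cos()))
```

The obvious downside is that the table has to be maintained by hand; it does not pick up new ops automatically.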

Well, PyTorch generates code from tools/autograd/derivatives.yaml:

- name: sin(Tensor self) -> Tensor
  self: grad * self.cos().conj()

- name: tanh(Tensor self) -> Tensor
  self: tanh_backward(grad, result)

Note that some backward formulas for y = F(x) need x (self) as input and some need y (result).
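For real inputs, writing grad for ∂L/∂y, the two YAML entries work out as follows (the `.conj()` is a no-op for real tensors). sin's rule needs the input x, while tanh's rule can be evaluated from the output y alone, which is why tanh_backward takes result:

$$
y = \sin x:\qquad \frac{\partial L}{\partial x} = \text{grad}\cdot\cos x
$$

$$
y = \tanh x:\qquad \frac{\partial L}{\partial x} = \text{grad}\cdot\left(1 - \tanh^2 x\right) = \text{grad}\cdot\left(1 - y^2\right)
$$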

The generated Functions are in torch/csrc/autograd/generated/Functions.cpp, but I’m not sure whether their interface is exported to Python.
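One way to reach those generated nodes from Python, sketched here with no guarantee that it is a supported API: every tensor produced by a differentiable op carries a `grad_fn` node. Attributes such as `_saved_self` and `_saved_result` expose what the formula kept, and calling the node itself (undocumented behaviour that may change between versions) applies its backward formula to an upstream gradient. The catch is that a node belongs to one recorded forward pass, so each new input needs a fresh forward call.

```python
import torch

x = torch.linspace(-1.0, 1.0, 5, requires_grad=True)
y_sin = torch.sin(x)
y_tanh = torch.tanh(x)

# Generated backward nodes (names such as SinBackward0 / TanhBackward0)
print(y_sin.grad_fn.name(), y_tanh.grad_fn.name())

# sin's formula saved its input (`self`); tanh's saved its output (`result`)
saved_x = y_sin.grad_fn._saved_self
saved_y = y_tanh.grad_fn._saved_result

# Calling a node runs its backward formula on the given upstream gradient
ones = torch.ones_like(x)
with torch.no_grad():
    d_sin = y_sin.grad_fn(ones)    # grad * cos(x)
    d_tanh = y_tanh.grad_fn(ones)  # grad * (1 - tanh(x)^2)

print(torch.allclose(d_sin, x.detach().cos()))
print(torch.allclose(d_tanh, 1 - y_tanh.detach().square()))
```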

If you want to differentiate a function directly within PyTorch, you’ll want to have a look at the functorch library and use its grad or jacrev/jacfwd transforms on a function to get a functional form of the derivative.

Here’s an example:

import torch
from functorch import grad, vmap

x = torch.randn(10) #samples

def f(x):
    return torch.sin(x)

grad_f = grad(f)                     #differentiate f(x) w.r.t x
df_dx = vmap(grad_f, in_dims=0)(x) #vmap over samples

exact = torch.cos(x) #analytical result

print("Check: ",torch.allclose(exact, df_dx)) #compare results (returns True)
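functorch has since been merged into PyTorch as `torch.func` (PyTorch 2.0 and later), so the same transforms are available without the separate package. A sketch that repeats the check with `torch.func`, also tries `jacrev`/`jacfwd`, and differentiates the composite `x.exp().sin()` from the question, which autograd handles by chaining the per-op rules:

```python
import torch
from torch.func import grad, jacfwd, jacrev, vmap

x = torch.randn(10)  # samples


def f(t):
    return torch.sin(t)


def g(t):
    return t.exp().sin()  # composite: no single registered rule


def h(t):
    return torch.tanh(t)


# grad needs a scalar output, so vmap maps it over the 10 samples
df_dx = vmap(grad(f))(x)
dg_dx = vmap(grad(g))(x)

# For an elementwise function the Jacobian is diagonal, with h'(x) on the diagonal
J_rev = jacrev(h)(x)
J_fwd = jacfwd(h)(x)

print(torch.allclose(df_dx, x.cos()))
print(torch.allclose(dg_dx, x.exp().cos() * x.exp()))
print(torch.allclose(torch.diagonal(J_rev), 1 - x.tanh().square()))
```

jacrev builds the Jacobian from reverse-mode vector-Jacobian products and jacfwd from forward-mode products; for a square elementwise map like this either works.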