Calculate derivative of function

Suppose I have a function:

def f(x): return torch.sin(x)

I can find the derivative of the function at a point as follows:

x = torch.tensor([2.], requires_grad=True)
y = f(x)
y.backward(retain_graph=True)
x.grad  # tensor([-0.4161]), i.e. cos(2)

However, this requires computing the value of f(x) in order to find its derivative, which should not be necessary

Other libraries, such as JAX, provide grad(f), which returns a function that computes cos(x)
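For reference, this is what that looks like in JAX (a minimal illustration of the API mentioned above):

import jax
import jax.numpy as jnp

# jax.grad builds a new function by program transformation,
# without first evaluating jnp.sin at a point
df = jax.grad(jnp.sin)
print(df(2.0))  # -0.41614684, i.e. cos(2.0)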

I was wondering if there is any PyTorch equivalent of this function?

I tried using y.grad_fn in conjunction with next_functions in order to reconstruct f’(x), but this did not seem to give correct values
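Roughly, that inspection looks like this (a sketch; the exact reprs vary across PyTorch versions):

import torch

x = torch.tensor([2.], requires_grad=True)
y = torch.sin(x)
print(y.grad_fn)                 # e.g. <SinBackward0 object at 0x...>
print(y.grad_fn.next_functions)  # e.g. ((<AccumulateGrad object at 0x...>, 0),)

These nodes encode how to propagate numeric gradients back through the recorded graph; they are not a symbolic expression for cos(x).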

Any ideas?

Hi,

PyTorch is a bit particular in the sense that the definition of your function is discovered at the same time as it is evaluated at a point. So you cannot get the gradient without evaluating the function at a given point. (This allows much more flexibility in what is allowed with respect to control flow and in-place operations.)
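For example (a small sketch of why the graph can only be known at evaluation time):

import torch

def f(x):
    # data-dependent control flow: which branch runs, and hence which
    # graph autograd records, depends on the value of x
    if x.sum() > 0:
        return torch.sin(x)
    return x ** 2

x = torch.tensor([2.], requires_grad=True)
f(x).backward()
print(x.grad)  # tensor([-0.4161]), i.e. cos(2): the sin branch was taken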

If you want static graphs that are differentiated symbolically (without evaluation), you can turn to TorchScript, which is going to do that.
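For example, scripting records a static graph you can inspect (a sketch; as discussed below, there is no public API that hands you a derivative function from it):

import torch

@torch.jit.script
def f(x):
    return torch.sin(x)

# the scripted function carries a static graph that the JIT's
# autodiff can differentiate internally
print(f.graph)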

Thanks, that sounds like exactly what I’m looking for! However, after looking online for a while, I’m still confused about how exactly to solve my problem using TorchScript.

You want to be able to get the grad function directly?

Yes, if it’s possible

The TorchScript API for differentiation is the same as the eager mode API, so if it is not expressible in Python, it won’t be expressible in TorchScript either. It is still possible to make something that looks like a grad function, but internally it will always compute the forward pass first:

def grad(f):
    def result(x):
        # make leaf variables out of the inputs
        x_ = x.detach().requires_grad_(True)
        # f must return a scalar (or single-element tensor) for backward()
        f(x_).backward()
        return x_.grad
    return result
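A quick sanity check with the sin example from the top of the thread:

x = torch.tensor([2.])
print(grad(torch.sin)(x))  # tensor([-0.4161]), i.e. cos(2)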

We do not do the kind of whole-program transformations that Mathematica or JAX do that would make generating the backward function possible.


@albanD What’s the current recommended way to do this?

@zdevito Is there a way to preserve gradients through that? If so, what’s the right way to do it?

To illustrate my use case (meta-optimization):

import torch

def grad(f):
    def result(x):
        # detach() cuts x off from the graph that produced it, so the
        # returned gradient is not itself differentiable
        x_ = x.detach().requires_grad_(True)
        f(x_).backward()
        return x_.grad
    return result

def f(x):
    return x**2

def grad_f(x):
    g1 = 2 * x       # analytic derivative, stays connected to the graph
    g2 = grad(f)(x)  # autograd derivative, detached from the graph
    assert (g1 == g2).all()
    print('g1: {!r}\ng2: {!r}'.format(g1, g2))
    return g1

def optimize(lr):
    # inner loop: gradient descent on f with step size lr
    x = torch.tensor(1.)
    for _ in range(10**2):
        x = x - lr * grad_f(x)
    return x

# outer loop: learn lr itself by differentiating through optimize()
lr = torch.tensor(1e-3, requires_grad=True)
opt = torch.optim.SGD([lr], 1e-6, .9)
while True:
    loss = f(optimize(lr))
    print('{:6.4f} {:6.4f}'.format(loss.item(), lr.item()))
    opt.zero_grad()
    loss.backward()
    opt.step()

Note that, after the first step, g1 is connected to the computation graph defined by lr, but g2 is not:

g1: tensor(2.)
g2: tensor(2.)
g1: tensor(1.9960, grad_fn=<MulBackward0>)
g2: tensor(1.9960)
g1: tensor(1.9920, grad_fn=<MulBackward0>)
g2: tensor(1.9920)
g1: tensor(1.9880, grad_fn=<MulBackward0>)
g2: tensor(1.9880)
g1: tensor(1.9840, grad_fn=<MulBackward0>)
g2: tensor(1.9840)

Hi,

I don’t think there is any specific API for this. Just compute the gradient the same way you would otherwise.
Note that while JAX provides a closure for grad, it may still re-compute the forward pass every time you use it in some cases.
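If the goal is to keep g2 differentiable (as in the meta-optimization example above), one sketch that should work, still evaluating f at the point, is to build the gradient with torch.autograd.grad and create_graph=True:

import torch

def grad(f):
    def result(x):
        # keep x connected to whatever graph produced it (e.g. lr);
        # only create a fresh leaf if x is not tracked at all
        if not x.requires_grad:
            x = x.detach().requires_grad_(True)
        y = f(x)
        # create_graph=True makes the returned gradient itself
        # differentiable, so it stays in the outer graph
        g, = torch.autograd.grad(y, x, create_graph=True)
        return g
    return result

With this version, g2 should carry a grad_fn after the first step, just like g1.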

Just checking: does PyTorch have any plans to implement such static derivative functions, akin to JAX?
