Can I use grad around a numpy code and compile it?

Hi,

I have a set of pure NumPy functions, and I want to calculate their gradients.
Can I do something like:

Z = grad(numpy_func)(x)

Since torch.compile can understand NumPy code, and it can also work with grad (Frequently Asked Questions — PyTorch main documentation), I wonder whether I can use grad (and other torch.func functions) on a NumPy function.

Hi @Roy-Kid,

To compute gradients of a function within the PyTorch framework, you need to either 1) use torch functions so that autograd can track the operations, or 2) define a custom torch.autograd.Function and manually define its backward pass so autograd can compute gradients.

The docs for torch.autograd.Function objects can be found here: Automatic differentiation package - torch.autograd — PyTorch 2.3 documentation
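For option 2, here is a minimal sketch of wrapping a NumPy function in a torch.autograd.Function. It uses the style where forward does not take ctx and a separate setup_context saves what backward needs; the PyTorch docs on extending torch.func describe this style as the one torch.func transforms such as torch.func.grad require. The class name NumpySin and the use of np.sin are illustrative only. The backward is written by hand, because autograd cannot see inside the NumPy call.

```python
import numpy as np
import torch
from torch.func import grad


class NumpySin(torch.autograd.Function):
    # forward takes no ctx: the style needed for torch.func transforms
    @staticmethod
    def forward(x):
        # Run the NumPy code on a detached copy; autograd cannot trace it
        out = np.sin(x.detach().cpu().numpy())
        return torch.as_tensor(out, device=x.device)

    @staticmethod
    def setup_context(ctx, inputs, output):
        (x,) = inputs
        ctx.save_for_backward(x)

    @staticmethod
    def backward(ctx, grad_out):
        (x,) = ctx.saved_tensors
        # Hand-written derivative: d/dx sin(x) = cos(x)
        return grad_out * torch.cos(x)


x = torch.randn(2, 3)
# torch.func.grad needs a scalar output, hence the .sum()
dx = grad(lambda t: NumpySin.apply(t).sum())(x)
print(torch.allclose(dx, torch.cos(x)))
```

The cost of this approach is that every NumPy function needs its own hand-written backward.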

Hi,

I want to use torch.func.grad to calculate the gradient, rather than torch.autograd. Since the NumPy function can be compiled to torch code, I don't think I need to implement the backward pass manually?

I still think you need to define the backward manually. Can you take the grad of a function compiled from NumPy?

You could always try torch.func.grad(compiled_fn) and see whether it works, but I suspect it won't, as torch.func works on PyTorch primitives.

I hope so… According to Frequently Asked Questions — PyTorch main documentation, torch.func.grad and torch.compile can be used together (I tested it on 2.4.0-dev).

What if I want to do the following:

import numpy as np
import torch

@torch.compile
def f(x):
    return np.sin(x)

def g(x):
    return torch.func.grad(f)(x)

x = torch.randn(2, 3)
g(x)

The error raised is: RuntimeError: Can't call numpy() on Tensor that requires grad. Use tensor.detach().numpy() instead. So I'd like to know whether this is just not supported yet, or whether it cannot be done at all.

Can you share a minimal reproducible example with your function?

Also, the example you shared can be written with torch functions only:

import torch

@torch.compile
def f(x):
    return torch.sin(x).sum()  # torch.func.grad needs a scalar output

def g(x):
    return torch.func.grad(f)(x)

x = torch.randn(2, 3)
g(x)

What version of PyTorch are you using? The docs state it needs to be 2.1 or later.
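A small sketch for checking the installed version from Python. It relies on torch.__version__ being a TorchVersion, which compares against plain strings with version-aware semantics rather than character-by-character string comparison, so dev builds such as 2.4.0-dev compare sensibly.

```python
import torch

print(torch.__version__)
# Version-aware comparison against the 2.1 minimum mentioned in the docs
supports_numpy_compile = torch.__version__ >= "2.1"
print("torch.compile NumPy support expected:", supports_numpy_compile)
```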

This one was tested on 2.4.0-dev:

import torch
import numpy as np
from torch import compile
from torch.func import grad

def fn(x):
    return np.power(x, 2).sum()

cfn = compile(fn)
gcfn = grad(compile(fn))
cgfn = compile(grad(fn))
cgcfn = compile(grad(compile(fn)))

arr = torch.tensor([1., 2., 3.])
print(fn(arr))
print(cfn(arr))
# print(gcfn(arr))  error
# print(cgfn(arr))  error
print(cgcfn(arr))  # wrong
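For reference, fn computes sum(x**2), so for arr = [1., 2., 3.] the value should be 14 and the gradient 2 * arr. A simple way to get a trustworthy baseline is to rewrite the function with torch ops and apply torch.func.grad eagerly, with no compilation involved; the compiled NumPy variants can then be compared against it. fn_torch below is a hypothetical rewrite, not part of the original code.

```python
import torch
from torch.func import grad


def fn_torch(x):
    # Same computation as the NumPy fn, written with torch ops
    return torch.pow(x, 2).sum()


arr = torch.tensor([1., 2., 3.])
value = fn_torch(arr)
g = grad(fn_torch)(arr)
print(value)  # tensor(14.)
print(g)      # tensor([2., 4., 6.])
```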

So perhaps you have to pass a fullgraph=True flag to torch.compile. The docs have the following example:

@torch.compile(fullgraph=True)
@torch.compiler.wrap_numpy
def numpy_fn(X, Y):
    return np.sum(X[:, :, None] * Y[:, None, :], axis=(-2, -1))

X = torch.randn(1024, 64, device="cuda", requires_grad=True)
Y = torch.randn(1024, 64, device="cuda")
Z = numpy_fn(X, Y)
assert isinstance(Z, torch.Tensor)
Z.backward()
# X.grad now holds the gradient of the computation
print(X.grad)

This snippet does not run on 2.4.0-dev (and I don't think it runs on a stable version either, because numpy_fn does not return a scalar), but the approach looks more reasonable.
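The non-scalar problem comes from backward() itself: called with no arguments, it only works on a single-element tensor, while numpy_fn returns one value per row. The eager sketch below swaps in an equivalent torch expression for the NumPy function (an assumption made so it does not depend on torch.compile; it also runs on CPU with small shapes) and shows two standard fixes: reduce to a scalar first, or pass the upstream gradient explicitly.

```python
import torch

torch.manual_seed(0)
X = torch.randn(8, 4, requires_grad=True)
Y = torch.randn(8, 4)


def torch_fn(X, Y):
    # Same computation as numpy_fn, written with torch ops; shape (8,)
    return torch.sum(X[:, :, None] * Y[:, None, :], dim=(-2, -1))


Z = torch_fn(X, Y)
# Z.backward() would raise: grad can be implicitly created only for scalar outputs

# Fix 1: reduce to a scalar before calling backward()
Z.sum().backward()
grad_sum = X.grad.clone()

# Fix 2: pass the upstream gradient explicitly (same result as fix 1 here)
X.grad = None
torch_fn(X, Y).backward(torch.ones(8))
grad_explicit = X.grad.clone()

# Z_i = (sum_j X_ij) * (sum_k Y_ik), so d(Z.sum())/dX_ij = sum_k Y_ik
expected = Y.sum(dim=1, keepdim=True).expand_as(X)
print(torch.allclose(grad_sum, expected, atol=1e-5))
```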

I need to use torch.func.grad rather than Tensor.backward() and Tensor.grad, because it makes higher-order derivatives more convenient to compute. I think this should work, but it does not:

@torch.compile(fullgraph=True)
@torch.compiler.wrap_numpy
def numpy_fn(X, Y):
    return np.sum(X[:, :, None] * Y[:, None, :], axis=(-2, -1))

X = torch.randn(1024, 64, device="cuda")
Y = torch.randn(1024, 64, device="cuda")
Z = torch.func.grad(numpy_fn)(X, Y)  # use func.grad not Tensor.grad, since may also use jvp etc.
assert isinstance(Z, torch.Tensor)
assert Z.device.type == "cuda"
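Independently of the compile issue, torch.func.grad also requires the function to return a single-element tensor, and numpy_fn returns shape (1024,). The eager sketch below again swaps in an equivalent torch expression (an assumption; CPU, small shapes) to show the scalar reduction with argnums, plus torch.func.jvp for the forward-mode case mentioned in the code comment.

```python
import torch
from torch.func import grad, jvp

torch.manual_seed(0)
X = torch.randn(8, 4)
Y = torch.randn(8, 4)


def torch_fn(X, Y):
    # Same computation as numpy_fn, written with torch ops; shape (8,)
    return torch.sum(X[:, :, None] * Y[:, None, :], dim=(-2, -1))


# grad needs a scalar output, so reduce first; argnums=0 differentiates w.r.t. X
dX = grad(lambda X, Y: torch_fn(X, Y).sum(), argnums=0)(X, Y)

# jvp pushes a tangent through the non-scalar function; the tangent for Y
# is zero, so this is the directional derivative along tX only
tX = torch.randn(8, 4)
out, tangent_out = jvp(torch_fn, (X, Y), (tX, torch.zeros_like(Y)))

expected_dX = Y.sum(dim=1, keepdim=True).expand_as(X)
expected_tangent = tX.sum(dim=1) * Y.sum(dim=1)
print(torch.allclose(dX, expected_dX, atol=1e-5))
print(torch.allclose(tangent_out, expected_tangent, atol=1e-5))
```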

Perhaps torch.func can only handle PyTorch primitives, whereas the .backward() approach works fine? It might be best to get a dev's opinion: @vfdev-5 (apologies for the tag)

When I run the script, I get the following error:

torch._dynamo.exc.Unsupported: torch.func.grad(fn) requires the function to be inlined by dynamo

Set TORCH_LOGS="+dynamo" and TORCHDYNAMO_VERBOSE=1 for more information

To be able to use torch.func over NumPy, or compiled functions in general, the relevant torch.func calls should also be within torch.compile.

As a side note, fullgraph=True does not affect this process and has no side effects; it is only there to make sure we get a single graph (compilation raises an error on a graph break instead of splitting it). wrap_numpy is meant to be used within torch.compile.

Taking all this into account, we get this script, which does compile and does what you’d expect:

import numpy as np
import torch
from torch.func import grad

@torch.compiler.wrap_numpy
def fn(x):
    return np.power(x, 2).sum()

@torch.compile
def my_grad(x):
    return torch.func.grad(fn)(x)

arr = torch.tensor([1., 2., 3.], requires_grad=True)
print(my_grad(arr))

The snippet from the docs quoted in Can I use grad around a numpy code and compile it? - #9 by Roy-Kid should work. I guess we had a small regression there. I'll submit a fix today or tomorrow.


@Lezcano @AlphaBetaGamma96 Thanks a lot for your patient help! I always thought this strategy should work and that there was only a small bug. I have a lot of NumPy energy calculators that I want to reuse in the NNP package, which is why I asked this ridiculous question. Once it is fixed, I will try again!


It turns out that the example was simply wrong: it called backward() on a non-scalar tensor. I sent a fix at [Docs] Fix NumPy + backward example by lezcano · Pull Request #126872 · pytorch/pytorch · GitHub.
torch.compile already works as advertised to compute gradients in NumPy code.

Note that the usual caveats apply: minimise graph breaks, and avoid explicit loops and data-dependent code.