Hi,
I’m trying to speed up code that computes derivatives of a network’s output with respect to its input (here, the second derivative), e.g.:
```python
import torch

model = torch.nn.Sequential(
    torch.nn.Linear(1, 10),
    torch.nn.Tanh(),
    torch.nn.Linear(10, 1),
)
f = lambda x: (4 * torch.pi**2) * torch.sin(2 * torch.pi * x)
x = torch.linspace(0, 1, 10).reshape(-1, 1).requires_grad_()

def res(x):
    u = model(x)
    # First derivative du/dx; create_graph=True so it can be differentiated again
    u_x = torch.autograd.grad(u, x, torch.ones_like(u), create_graph=True)[0]
    # Second derivative d2u/dx2
    u_xx = torch.autograd.grad(u_x, x, torch.ones_like(u_x), create_graph=True)[0]
    # Residual of u'' + f = 0
    return u_xx + f(x)
```
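As a sanity check on what `res` computes: with $f(x) = 4\pi^2 \sin(2\pi x)$, the function $u(x) = \sin(2\pi x)$ has $u''(x) = -4\pi^2 \sin(2\pi x) = -f(x)$, so the residual $u'' + f$ should be zero for it. The sketch below runs the same two `torch.autograd.grad` calls on that analytic $u$ instead of the network. `residual` is a hypothetical helper, and float64 is used only so rounding doesn't hide the check.

```python
import torch

# Same source term as in the post
f = lambda x: (4 * torch.pi**2) * torch.sin(2 * torch.pi * x)

def residual(u_fn, x):
    # Hypothetical helper: the same derivative chain as res(), for any callable u_fn
    u = u_fn(x)
    u_x = torch.autograd.grad(u, x, torch.ones_like(u), create_graph=True)[0]
    u_xx = torch.autograd.grad(u_x, x, torch.ones_like(u_x), create_graph=True)[0]
    return u_xx + f(x)

# float64 so the check is not dominated by float32 rounding
x = torch.linspace(0, 1, 10, dtype=torch.float64).reshape(-1, 1).requires_grad_()

# u(x) = sin(2*pi*x) gives u''(x) = -4*pi^2*sin(2*pi*x) = -f(x), so the residual should vanish
r = residual(lambda t: torch.sin(2 * torch.pi * t), x)
print(r.abs().max().item())
```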
I expected `torch.compile` to speed up the computation, but when I add `@torch.compile` above `def res`, calling the function fails. I haven’t found any examples of compiling functions that take derivatives with `torch.autograd.grad`, so I’m not sure why this happens.
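Since "fails" could mean an error during tracing, in the compiler backend, or only later when the second `autograd.grad` call runs, capturing the exact message may help narrow it down. A minimal sketch, assuming PyTorch 2.x where `torch.compile` exists. The try/except fallback is only a hypothetical debugging aid, not a fix.

```python
import torch

model = torch.nn.Sequential(
    torch.nn.Linear(1, 10),
    torch.nn.Tanh(),
    torch.nn.Linear(10, 1),
)
f = lambda x: (4 * torch.pi**2) * torch.sin(2 * torch.pi * x)
x = torch.linspace(0, 1, 10).reshape(-1, 1).requires_grad_()

def res(x):
    u = model(x)
    u_x = torch.autograd.grad(u, x, torch.ones_like(u), create_graph=True)[0]
    u_xx = torch.autograd.grad(u_x, x, torch.ones_like(u_x), create_graph=True)[0]
    return u_xx + f(x)

try:
    compiled_res = torch.compile(res)
    # Compilation is lazy: tracing and backend errors surface on the first call
    out = compiled_res(x)
    print("torch.compile succeeded")
except Exception as exc:
    # Print only the first line of the (often long) error message
    first_line = str(exc).split("\n", 1)[0]
    print(f"torch.compile failed: {type(exc).__name__}: {first_line}")
    out = res(x)  # fall back to eager autograd
```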
I did a small test using `torch.func` instead of `torch.autograd` for the derivatives, and it seems to work. However, I suspect that calling `autograd` directly might be faster, since it avoids the extra layers of function calls (`grad` nested twice, `vmap`, and a functional call of the model).
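The `torch.func` test isn't shown here, so the following is only a guess at what such a version could look like. It is a minimal sketch assuming `torch.func.functional_call`, `grad`, and `vmap` (PyTorch 2.0+), with hypothetical helper names `u_scalar` and `res_func`. It checks that both versions agree and then times them on CPU, which is more reliable than reasoning about call overhead.

```python
import time

import torch
from torch.func import functional_call, grad, vmap

torch.manual_seed(0)
model = torch.nn.Sequential(
    torch.nn.Linear(1, 10),
    torch.nn.Tanh(),
    torch.nn.Linear(10, 1),
)
f = lambda x: (4 * torch.pi**2) * torch.sin(2 * torch.pi * x)
x = torch.linspace(0, 1, 10).reshape(-1, 1).requires_grad_()

def res(x):
    # Autograd version from the post, used as the reference
    u = model(x)
    u_x = torch.autograd.grad(u, x, torch.ones_like(u), create_graph=True)[0]
    u_xx = torch.autograd.grad(u_x, x, torch.ones_like(u_x), create_graph=True)[0]
    return u_xx + f(x)

params = dict(model.named_parameters())

def u_scalar(params, xi):
    # Hypothetical helper: evaluate the model at one 0-dim input, return a 0-dim output
    return functional_call(model, params, (xi.reshape(1, 1),)).squeeze()

# d2u/dx2 at a single point: differentiate twice w.r.t. the input (argnums=1)
u_xx_scalar = grad(grad(u_scalar, argnums=1), argnums=1)

def res_func(params, x):
    # vmap over the N points; params are shared across the batch (in_dims=None)
    u_xx = vmap(u_xx_scalar, in_dims=(None, 0))(params, x.squeeze(-1))
    return u_xx.unsqueeze(-1) + f(x)

def bench(fn, n=100):
    # Plain wall-clock timing on CPU; on GPU, synchronize before reading the clock
    fn()  # warm-up call, excluded from the timing
    start = time.perf_counter()
    for _ in range(n):
        fn()
    return (time.perf_counter() - start) / n

print("outputs match:", torch.allclose(res(x), res_func(params, x), atol=1e-5))
t_autograd = bench(lambda: res(x))
t_func = bench(lambda: res_func(params, x))
print(f"autograd: {t_autograd * 1e3:.3f} ms/call, torch.func: {t_func * 1e3:.3f} ms/call")
```

For GPU timings, `torch.utils.benchmark.Timer` takes care of CUDA synchronization and may be the easier choice.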
Does anyone know whether `torch.compile` can work efficiently with `torch.autograd`? Or is `torch.func` the only viable approach?
Thanks,