`compile` a function with `autograd`

Hi,
I’m trying to speed up my code, which involves computing derivatives, e.g.:

import torch
model = torch.nn.Sequential(
    torch.nn.Linear(1, 10),
    torch.nn.Tanh(),
    torch.nn.Linear(10, 1)
)

f = lambda x: (4 * torch.pi**2) * torch.sin(2 * torch.pi * x)
x = torch.linspace(0, 1, 10).reshape(-1, 1).requires_grad_()

def res(x):
    # residual u_xx + f(x), with nested autograd.grad calls on the input x
    u = model(x)
    u_x = torch.autograd.grad(u, x, torch.ones_like(u), create_graph=True)[0]
    u_xx = torch.autograd.grad(u_x, x, torch.ones_like(u_x), create_graph=True)[0]
    return u_xx + f(x)

I expected torch.compile to speed this up, but when I add @torch.compile above def res, the compiled function fails with an error. I haven’t found examples of compiling functions that compute derivatives with autograd, so I’m unsure why this happens.
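For reference, the failing setup is essentially just wrapping res and calling it on x (equivalent to the decorator form):

res_compiled = torch.compile(res)
out = res_compiled(x)  # this call is where the error occurs for me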

I did a small test using torch.func instead of torch.autograd for the derivative calculations, and that version does seem to compile. However, I suspect that using autograd directly might be faster, since it avoids multiple function transforms (nested grad, vmap, and a functional call of the model).
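For context, my torch.func test looked roughly like this (the helper names u_scalar and res_func are just for illustration):

from torch.func import functional_call, grad, vmap

params = dict(model.named_parameters())

def u_scalar(xi):
    # evaluate the model at a single scalar input xi
    out = functional_call(model, params, (xi.reshape(1, 1),))
    return out.squeeze()

def res_func(x):
    # second derivative per sample via nested grad, vectorized over the batch
    u_xx = vmap(grad(grad(u_scalar)))(x.squeeze(-1))
    return u_xx.unsqueeze(-1) + f(x)

res_func_compiled = torch.compile(res_func)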

Does anyone know whether torch.compile can work efficiently with autograd, or is torch.func the only viable approach?

Thanks,

Are you looking for something like Compiled Autograd: Capturing a larger backward graph for torch.compile — PyTorch Tutorials 2.6.0+cu124 documentation?
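Roughly, per that tutorial, compiled autograd is enabled with a dynamo config flag so that the .backward() call inside a compiled region is also captured (a minimal sketch based on the tutorial, not your residual code):

import torch

torch._dynamo.config.compiled_autograd = True  # flag from the compiled autograd tutorial

model = torch.nn.Linear(10, 1)
x = torch.randn(8, 10)

@torch.compile
def train_step(model, x):
    loss = model(x).sum()
    loss.backward()  # backward graph captured by compiled autograd
    return loss

train_step(model, x)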