Does torch.compile support gradient operations in the model?

I need to calculate the gradient of the network output with respect to the input inside my model. When the model is not compiled, training runs normally; when it is compiled, an error is reported while calculating the gradient:

unsupported operand type(s) for *: 'Tensor' and 'NoneType'
File "/home/yklei/practice/mlmm_energy/test/debug/mlmm/model/model_pl.py", line 148, in
f = grad(
File "/home/yklei/practice/mlmm_energy/test/debug/mlmm/model/model_pl.py", line 89, in forward
for key in g.ndata.keys():
File "/home/yklei/practice/mlmm_energy/test/debug/mlmm/model/model_pl.py", line 220, in training_step
results = self(g_qmmm, cell = cell)
File "/home/yklei/practice/mlmm_energy/test/debug/mlmm_main_pl.py", line 38, in
cli = MyLightningCLI(LitMLMM, Molecule_DataModule)#, subclass_mode_model=True)
TypeError: unsupported operand type(s) for *: 'Tensor' and 'NoneType'

I wrote a simple piece of code to check whether the gradient operation is supported when the model is compiled:

import torch

def fn(x, y):
    x.requires_grad_()
    y.requires_grad_()
    a = torch.cos(x).cuda()
    b = torch.sin(y).cuda()
    fmm = torch.autograd.grad(
        a + b,
        [x, y],
        grad_outputs=torch.ones_like(a + b),
        create_graph=True,
        retain_graph=True,
    )
    return a + b, fmm

new_fn = torch.compile(fn, backend="inductor")
input_tensor = torch.randn(10000).to(device="cuda:0")
a, fmm = new_fn(input_tensor, input_tensor)
print(a)
print(fmm)

It indeed reported an error:

RuntimeError: element 0 of tensors does not require grad and does not have a grad_fn

Is there any way to solve this?

Late reply, but:

(1) Any chance you’ve tried this on a more recent nightly? I believe this should work now.

(2) The high-level problem is that torch.compile works best when capturing large backward graphs, but suffers a bit when you want to do more fine-grained surgery on the backward graph. In your example, you’re changing the requires_grad-ness of your inputs halfway through the function, which is a pain to support. So as of this PR: [dynamo] Add graph break on requires_grad_() by int3 · Pull Request #110053 · pytorch/pytorch · GitHub, we now graph break on x.requires_grad_(True).
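
A quick way to see this graph break in practice (a hedged sketch of my own, not part of the original reply): run the reproducer with TORCH_LOGS="graph_breaks", or use torch._dynamo.explain. Note that the explain API has changed across releases; the call style below assumes a reasonably recent PyTorch (roughly 2.1 or newer).

import torch

def fn(x, y):
    # requires_grad_() on the inputs is expected to trigger a graph break
    # after PR #110053
    x.requires_grad_()
    y.requires_grad_()
    return torch.cos(x) + torch.sin(y)

x = torch.randn(8)
y = torch.randn(8)

# Report how many graphs were produced and why compilation broke;
# alternatively, run the script with TORCH_LOGS="graph_breaks"
explanation = torch._dynamo.explain(fn)(x, y)
print(explanation.graph_break_count)
print(explanation.break_reasons)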

For better performance, I’d recommend moving the requires_grad_() (and probably also the .cuda()) calls outside of the compiled region, and letting torch.compile directly handle just the compute (cos/sin) and the backward calls, as in the sketch below.
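
A minimal sketch of that refactor (my adaptation, not the original poster’s code): device placement and requires_grad are set up before calling the compiled function, so torch.compile only captures the cos/sin compute and the autograd.grad call.

import torch

def fn(x, y):
    a = torch.cos(x)
    b = torch.sin(y)
    out = a + b
    # gradient of the output w.r.t. both inputs, kept inside the compiled region;
    # create_graph=True so the result can itself be differentiated during training
    fmm = torch.autograd.grad(
        out,
        [x, y],
        grad_outputs=torch.ones_like(out),
        create_graph=True,
        retain_graph=True,
    )
    return out, fmm

compiled_fn = torch.compile(fn, backend="inductor")

# requires_grad_() and .cuda()/device placement stay outside the compiled region
x = torch.randn(10000, device="cuda:0").requires_grad_()
y = torch.randn(10000, device="cuda:0").requires_grad_()

out, fmm = compiled_fn(x, y)
print(out)
print(fmm)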

Thank you for your replies! As you suggested, I used a newer version and it works now. But it seems that torch.compile doesn’t support modules that use an external library such as Deep Graph Library (DGL).