Torch.compile interaction with autocast and no_grad

Where should I put no_grad and autocast contexts? Inside the function that will be torch.compile’d? Or should I torch.compile first and then wrap the calls with autocast / no_grad?

In theory, knowing that we are inside an autocast context, or that we don't need to save intermediate tensors for backward, could enable stronger optimizations such as more aggressive in-place rewriting and more economical memory placement.

Today, if you try to torch.compile a module / function that internally uses autocast/no_grad context managers, dynamo will graph break on them.
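For example, here is a minimal sketch of the pattern that currently triggers the graph break (the function name, shapes, and CUDA device are just illustrative):

import torch

@torch.compile
def g(x, w):
    # dynamo graph-breaks when it reaches the context manager below,
    # so the function is compiled as several smaller graphs instead of one
    with torch.cuda.amp.autocast():
        return x @ w

x = torch.randn(8, 8, device="cuda")
w = torch.randn(8, 8, device="cuda")
out = g(x, w)  # still correct, just less aggressively optimized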

So I’d recommend putting them outside for now:

@torch.compile
def f(args):
    ...

with torch.cuda.amp.autocast():
    out = f(...)

with torch.no_grad():
    out = f(...)

I think we want to fix this and avoid graph breaking on these context managers, so the longer-term answer is “it shouldn’t matter” - feel free to file an issue in the meantime though!


Thank you! I created a sister discussion: Interaction of torch.no_grad and torch.autocast context managers with torch.compile · Issue #100241 · pytorch/pytorch · GitHub