Torch.compile interaction with autocast and no_grad

Where should I put no_grad and autocast contexts? Inside the function that will be torch.compile’d? Or should I torch.compile first and then wrap the calls with autocast / no_grad?

In theory, knowing the autocast context, and knowing whether intermediate tensors need to be saved for backward, can enable stronger optimizations: more aggressive in-place rewriting and more economical memory placement.

Today, if you try to torch.compile a module / function that internally uses autocast/no_grad context managers, dynamo will graph break on them.

So I’d recommend putting them outside for now:

    def f(args):
        ...  # no autocast/no_grad context managers inside the compiled function

    compiled_f = torch.compile(f)

    with torch.cuda.amp.autocast():
        out = compiled_f(args)

    with torch.no_grad():
        out = compiled_f(args)
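The pattern above can be sketched end to end. Here is a minimal runnable example with a toy function `f` (my assumption for illustration); `backend="eager"` is passed only so the sketch runs without a GPU or C++ toolchain, and `torch.autocast("cpu", ...)` stands in for the CUDA autocast shown above for the same reason:

```python
import torch

def f(x):
    # Plain computation; no autocast/no_grad inside the compiled region.
    return x @ x

compiled_f = torch.compile(f, backend="eager")

x = torch.randn(4, 4, requires_grad=True)

# no_grad wraps the call site, not the body of f.
with torch.no_grad():
    out = compiled_f(x)

# autocast likewise wraps the call site; on GPU you would use
# torch.cuda.amp.autocast() instead, as in the snippet above.
with torch.autocast("cpu", dtype=torch.bfloat16):
    out2 = compiled_f(x.detach())
```

Under `no_grad` the output carries no autograd history (`out.requires_grad` is `False`), and under autocast the matmul runs in bfloat16, all without graph breaks inside the compiled function.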

I think we want to fix this though, and avoid graph breaking on these context managers. So longer term the answer is “it shouldn’t matter” - feel free to file an issue though!


Thank you! I created a sister discussion: Interaction of torch.no_grad and torch.autocast context managers with torch.compile · Issue #100241 · pytorch/pytorch · GitHub