@Mark_Hanslip Glad you’re trying the native API! The full import paths are torch.cuda.amp.autocast and torch.cuda.amp.GradScaler. Often, for brevity, usage snippets don’t show the full import paths, silently assuming the names were imported earlier and that you skimmed the class or function declaration/header to obtain each path. For example, a snippet that shows
@autocast()
def forward...
silently assumes you wrote from torch.cuda.amp import autocast earlier in the script.
Try from torch.cuda.amp import autocast at the top of your script, or alternatively
@torch.cuda.amp.autocast()
def forward...
and treat GradScaler the same way.
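Put together, a minimal sketch of the decorator route could look like the following. (MyModel, the layer sizes, and the dummy input are placeholders I made up for illustration, not anything from your script.)

    import torch
    import torch.nn as nn
    from torch.cuda.amp import autocast

    class MyModel(nn.Module):          # placeholder model
        def __init__(self):
            super().__init__()
            self.linear = nn.Linear(128, 10)

        @autocast()                    # ops inside forward run in mixed precision
        def forward(self, x):
            return self.linear(x)

    model = MyModel().cuda()
    out = model(torch.randn(8, 128, device="cuda"))   # forward executes under autocast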
Implicit imports for brevity are common practice throughout the PyTorch docs, but that may not be obvious if you’re relatively new to them.
A separate concern is that the loss computation(s), in addition to the forward() methods, should run under autocast (for which you could use the context-manager form, with autocast():).
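Concretely, the usual pattern looks roughly like this; the model, optimizer, loss_fn, and data below are made-up stand-ins so the sketch is self-contained, and only the autocast/GradScaler structure is the point:

    import torch
    import torch.nn as nn
    from torch.cuda.amp import autocast, GradScaler

    # Placeholder model, optimizer, loss, and data -- substitute your own.
    model = nn.Linear(128, 10).cuda()
    optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)
    loss_fn = nn.CrossEntropyLoss()
    data = [(torch.randn(32, 128, device="cuda"),
             torch.randint(0, 10, (32,), device="cuda")) for _ in range(4)]

    scaler = GradScaler()
    for inputs, targets in data:
        optimizer.zero_grad()
        with autocast():                      # forward AND the loss run in mixed precision
            outputs = model(inputs)
            loss = loss_fn(outputs, targets)
        scaler.scale(loss).backward()         # backward runs outside the autocast region
        scaler.step(optimizer)                # unscales grads; skips the step on infs/NaNs
        scaler.update()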
The multi-model example is likely relevant as well. (retain_graph in that example has nothing to do with Amp; it’s present so the non-Amp parts shown are functionally correct, so you can ignore it.)
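For orientation, the shape of that multi-model pattern is roughly the following. This is a schematic paraphrase, not the docs example verbatim; the two linear models, the loss weighting, and the random data are placeholders. retain_graph=True on the first backward is only there because the two losses share parts of the graph:

    import torch
    import torch.nn as nn
    from torch.cuda.amp import autocast, GradScaler

    # Placeholder models/optimizers; only the Amp structure matters here.
    model0 = nn.Linear(128, 10).cuda()
    model1 = nn.Linear(128, 10).cuda()
    optimizer0 = torch.optim.SGD(model0.parameters(), lr=1e-3)
    optimizer1 = torch.optim.SGD(model1.parameters(), lr=1e-3)
    loss_fn = nn.MSELoss()

    inputs = torch.randn(32, 128, device="cuda")
    targets = torch.randn(32, 10, device="cuda")

    scaler = GradScaler()
    optimizer0.zero_grad()
    optimizer1.zero_grad()
    with autocast():
        out0 = model0(inputs)
        out1 = model1(inputs)
        # Both losses depend on both models, so their graphs overlap.
        loss0 = loss_fn(0.5 * (out0 + out1), targets)
        loss1 = loss_fn(0.5 * (out0 - out1), targets)
    # retain_graph=True is ordinary autograd bookkeeping for the shared graph,
    # not an Amp requirement.
    scaler.scale(loss0).backward(retain_graph=True)
    scaler.scale(loss1).backward()
    scaler.step(optimizer0)
    scaler.step(optimizer1)
    scaler.update()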