Do we need to do torch.cuda.amp.autocast(enabled=False) before a custom function?

We have some code that is designed to run in mixed precision and contains a custom CUDA kernel call. Currently we place a with torch.cuda.amp.autocast(enabled=False) guard around the custom CUDA kernel call, and then cast all tensors to the same type. Is the autocast(enabled=False) really necessary? What might happen without this block? (Context: this pull request, which proposes to move the casts out of the autocast guard block to enable compatibility with Apex: Enable both Pytorch native AMP and Nvidia APEX AMP for SRU by visionscaper · Pull Request #158 · asappresearch/sru · GitHub)

If the function accepts arguments in any precision, and doesn’t even care if they match, you probably don’t need to do anything.

If you’re wrapping the op with a torch.autograd.Function:
If the kernel needs matching input dtypes, but those dtypes can be anything as long as they match, the approach you describe is probably the best.
If the CUDA kernel needs one particular dtype, you can do it a bit more conveniently (see below).
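For the "inputs must match, but any dtype is fine" case, a minimal sketch of the pattern described above might look like the following. FusedOp and fused_op are hypothetical names, and the forward body is a placeholder standing in for the real custom CUDA kernel call:

```python
import torch

class FusedOp(torch.autograd.Function):
    # Hypothetical stand-in for a custom op whose kernel requires
    # matching input dtypes (any dtype, as long as they agree).
    @staticmethod
    def forward(ctx, x, weight):
        ctx.save_for_backward(x, weight)
        return x * weight  # placeholder for the custom CUDA kernel call

    @staticmethod
    def backward(ctx, grad_out):
        x, weight = ctx.saved_tensors
        return grad_out * weight, grad_out * x

def fused_op(x, weight):
    # Disable autocast locally, then cast everything to a single dtype
    # (here: x's dtype) so the kernel sees matching inputs.
    with torch.cuda.amp.autocast(enabled=False):
        return FusedOp.apply(x, weight.to(x.dtype))
```

Under autocast, x may arrive as float16 while weight (e.g. a float32 Parameter) does not; the weight.to(x.dtype) cast makes them match before the kernel runs.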

If you’re registering the op with torch on the C++ side, see the autocast section of the dispatch registration tutorial.

autocast(enabled=False) does not affect Apex AMP. torch.cuda.amp is meant as the permanent and complete replacement for Apex AMP, so you should prefer torch.cuda.amp.


Ok. Thank you for the very useful information. Question: what might go wrong if we simply cast everything to the same type (16 vs. 32 bits) as one of the tensors, but without using an autocast(enabled=False) guard around this?

If the forward pass of your op consists only of the custom kernel itself, that’s fine. Autocast doesn’t know about your kernel (unless you register it like in the dispatch tutorial) and won’t touch the inputs.

If your op consists of your custom kernel plus a few torch.* ops, and you don't locally autocast(enabled=False), the torch.* ops might still be affected by autocast, which you may or may not want.
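To make the second case concrete, here is a hedged sketch. custom_kernel and my_op are hypothetical names, and custom_kernel is a placeholder for the real kernel; without the guard, the torch.matmul in the op body could run in reduced precision under autocast even though the inputs were cast to float32:

```python
import torch

def custom_kernel(a, b):
    # Placeholder standing in for a real custom CUDA kernel
    # that requires float32 inputs.
    return a * b

def my_op(x, weight):
    # The guard keeps the surrounding torch.* ops (here torch.matmul)
    # from being autocast to float16 when called under autocast.
    with torch.cuda.amp.autocast(enabled=False):
        x32, w32 = x.float(), weight.float()
        y = custom_kernel(x32, w32)
        return torch.matmul(y, w32)  # stays float32 thanks to the guard
```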

For torch.autograd.Functions, the @custom_fwd(cast_inputs=torch.float32) decorator (or cast_inputs=torch.float16) takes care of casting inputs and disabling autocast for the forward body.
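A minimal sketch of that decorator-based approach, assuming a kernel that needs float32; Float32Op is a hypothetical name and its forward body is a placeholder for the real kernel call:

```python
import torch
from torch.cuda.amp import custom_fwd, custom_bwd

class Float32Op(torch.autograd.Function):
    # cast_inputs=torch.float32 casts incoming CUDA floating-point
    # tensors to float32 and runs forward() with autocast disabled.
    @staticmethod
    @custom_fwd(cast_inputs=torch.float32)
    def forward(ctx, x, weight):
        ctx.save_for_backward(x, weight)
        return x * weight  # placeholder for a float32-only CUDA kernel

    @staticmethod
    @custom_bwd
    def backward(ctx, grad_out):
        x, weight = ctx.saved_tensors
        return grad_out * weight, grad_out * x
```

Note that the casting only kicks in when the call happens inside an autocast-enabled region; outside of autocast, inputs pass through unchanged.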