We have some code that is designed to run in mixed precision. It contains a custom CUDA kernel call. Currently we place a with torch.cuda.amp.autocast(enabled=False) guard around the custom CUDA kernel call, and then cast all tensors to the same type. Is the autocast(enabled=False) really necessary? What might happen without this block? (Context: this pull request, which proposes to move the casts out of the autocast guard block to enable compatibility with APEX: Enable both Pytorch native AMP and Nvidia APEX AMP for SRU by visionscaper · Pull Request #158 · asappresearch/sru · GitHub.)
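For reference, the current pattern looks roughly like this; fused_custom_kernel and the argument names are placeholders for our actual kernel binding, not the real SRU code:

```python
import torch

# Hypothetical stand-in for the compiled custom CUDA kernel binding.
from my_extension import fused_custom_kernel

def call_kernel(x, weight, bias):
    # Current pattern: disable autocast locally, then cast every input to
    # one common dtype before the custom CUDA kernel call.
    with torch.cuda.amp.autocast(enabled=False):
        dtype = weight.dtype
        return fused_custom_kernel(x.to(dtype), weight, bias.to(dtype))
```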
If the function accepts arguments in any precision, and doesn’t even care if they match, you probably don’t need to do anything.
If you’re wrapping the op with a torch.autograd.Function:
If the kernel needs matching input dtypes, but those dtypes can be anything as long as they match, the approach you describe is probably the best (there's a sketch of it right after this list).
If the CUDA kernel needs a particular dtype, you can do it a bit more conveniently (with custom_fwd(cast_inputs=...), shown further down).
If you’re registering the op with torch on the C++ side, see the autocast section of the dispatch registration tutorial.
autocast(enabled=False) does not affect apex amp. torch.cuda.amp is meant as the permanent and complete replacement, so you should prefer torch.cuda.amp.
Ok. Thank you for the very useful information. Question: what might go wrong if we simply cast everything to the same type (16- vs 32-bit) as one of the tensors, but without using an autocast(enabled=False) guard around this?
If the forward pass of your op consists only of the custom kernel itself, that’s fine. Autocast doesn’t know about your kernel (unless you register it like in the dispatch tutorial) and won’t touch the inputs.
If your op consists of your custom kernel + a few torch.* ops, and you don't locally autocast(enabled=False), the torch.* ops still might be affected by autocast, which you may or may not want.
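For example (the custom-kernel comment is hypothetical; torch.mm shows the standard autocast casting under the default float16 autocast dtype):

```python
import torch

with torch.cuda.amp.autocast():
    a = torch.randn(8, 8, device="cuda", dtype=torch.float32)
    b = torch.randn(8, 8, device="cuda", dtype=torch.float32)
    # A custom kernel call such as my_kernel(a, b) (hypothetical) would receive
    # these float32 tensors untouched: autocast has no casting rule for an op
    # it does not know about.
    out = torch.mm(a, b)    # autocast casts a and b to float16 for torch.mm
    print(out.dtype)        # torch.float16, despite the explicit float32 inputs

with torch.cuda.amp.autocast():
    with torch.cuda.amp.autocast(enabled=False):
        out = torch.mm(a, b)    # local guard: runs in the dtypes you passed
    print(out.dtype)            # torch.float32
```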
For torch.autograd.Functions, @custom_fwd(cast_inputs=torch.float32) (or torch.float16) takes care of casting the inputs and disabling autocast for the forward body.
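A sketch of that decorator approach, again with hypothetical kernel bindings:

```python
import torch
from torch.cuda.amp import custom_fwd, custom_bwd

# Hypothetical stand-ins for the real kernel entry points.
from my_extension import fused_forward, fused_backward

class FusedOpFP32(torch.autograd.Function):
    @staticmethod
    @custom_fwd(cast_inputs=torch.float32)   # casts floating-point CUDA inputs to fp32
    def forward(ctx, x, weight, bias):       # and runs this body with autocast disabled
        out = fused_forward(x, weight, bias)
        ctx.save_for_backward(x, weight, bias)
        return out

    @staticmethod
    @custom_bwd                              # backward runs with the same autocast state as forward
    def backward(ctx, grad_out):
        x, weight, bias = ctx.saved_tensors
        # Assumed to return (grad_x, grad_weight, grad_bias).
        return fused_backward(grad_out, x, weight, bias)
```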
Could you explain more how the torch.* ops will be affected by the explicit casting if we don't do it inside a disabled-autocast context? Do you mean it can lead to a difference in precision in cases where the output of the kernel is fp32 and a torch.mm follows it, so the torch.mm will execute in bf16? If that is the case, then it's fine, because this is not misbehavior, it's expected. What I want to understand is whether there is a hard rule of "no explicit casting" inside an autocast-enabled context because it can lead to unexplained, weird behaviors. If that is not the case, then the user just needs to be careful about the usage.