Torch AMP autocast combined with recompute (torch.utils.checkpoint)

Hi, I would like to understand how torch.autocast() interacts with activation recomputation (torch.utils.checkpoint).

In the official examples, loss.backward() is called outside the with torch.autocast(...): region, so the re-forward of the checkpointed segment during backward would run without autocast enabled. Does that mean the recomputed forward, and therefore the gradient computation, actually runs entirely in fp32 rather than fp16? Is my understanding correct?
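
To make the setup I am asking about concrete, here is a minimal sketch (the model, shapes, and loss are just placeholders I made up for illustration): the checkpointed forward runs inside autocast, while backward() is called outside it, matching the pattern in the official docs.

```python
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint

# Hypothetical toy model, only to illustrate the question.
model = nn.Sequential(nn.Linear(1024, 1024), nn.ReLU(), nn.Linear(1024, 1024)).cuda()
x = torch.randn(8, 1024, device="cuda")
target = torch.randn(8, 1024, device="cuda")

# Forward pass under autocast; the checkpointed segment drops its
# activations and will be re-run (re-forward) during backward.
with torch.autocast(device_type="cuda", dtype=torch.float16):
    y = checkpoint(model, x, use_reentrant=False)
    loss = nn.functional.mse_loss(y, target)

# backward() is outside the autocast region, as in the official examples;
# the recomputation of the checkpointed segment happens at this point.
loss.backward()
```

My question is about what happens inside that loss.backward() call: does the recomputed forward of the checkpointed segment still run in fp16 (as during the original forward), or does it fall back to fp32 because the autocast context is no longer active?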