Mixed Precision Gradient Checking

I created a torch.autograd.Function that uses custom kernels in the forward and backward passes, and these kernels expect half-precision inputs. The network as a whole is trained in mixed precision using torch.autocast in fp16 or bf16, and the forward and backward methods of the Function are decorated with torch.cuda.amp.custom_fwd and torch.cuda.amp.custom_bwd.
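
For concreteness, a minimal sketch of the kind of setup I mean; the kernel names (my_half_kernel_fwd / my_half_kernel_bwd) are placeholders for my actual CUDA kernels:

```python
import torch
from torch.cuda.amp import custom_fwd, custom_bwd

class MyHalfOp(torch.autograd.Function):
    @staticmethod
    @custom_fwd(cast_inputs=torch.float16)  # the kernels expect fp16 inputs
    def forward(ctx, x):
        ctx.save_for_backward(x)
        return my_half_kernel_fwd(x)  # placeholder for the custom forward kernel

    @staticmethod
    @custom_bwd
    def backward(ctx, grad_out):
        (x,) = ctx.saved_tensors
        return my_half_kernel_bwd(grad_out, x)  # placeholder for the custom backward kernel
```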

Are there any suggestions for running gradcheck on a mixed-precision function like this to verify that my backward implementation is correct?

Couldn’t you disable autocast for the test and check the correctness of the gradients in float64?
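
Something like the sketch below, assuming your op (or a reference path inside it) can actually run in double precision outside autocast; MyHalfOp stands in for your Function:

```python
import torch
from torch.autograd import gradcheck

# Double-precision inputs: gradcheck's finite-difference Jacobian needs fp64 accuracy.
x = torch.randn(8, 16, dtype=torch.float64, device="cuda", requires_grad=True)

# Explicitly disable autocast so nothing gets cast down to fp16/bf16 during the test.
with torch.autocast(device_type="cuda", enabled=False):
    ok = gradcheck(MyHalfOp.apply, (x,), eps=1e-6, atol=1e-4)

print("gradcheck passed:", ok)
```

Note that custom_fwd(cast_inputs=...) only casts when the forward runs inside an autocast-enabled region, so with autocast off the float64 tensors reach your kernels unchanged. That is what you want for gradcheck, but it does mean the kernels (or a fallback implementation) have to accept float64 for this test.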