Does the dtype used for autograd’s backpropagation depend on the dtype used in the forward function?

I ran into an fp16 overflow in the forward pass, so I changed some modules (not all of them) to fp32. That works well but is slow. I then tried switching those fp32 modules to bf16 to speed things up, but this causes an fp16 overflow in the backward pass.
I don’t fully understand the autograd mechanism: why does changing the dtype used in the forward function cause an fp16 overflow in the backward pass? Does the dtype used for autograd’s backpropagation depend on the dtype used in the forward function? Or does the precision change simply alter the loss and gradients seen in the backward pass?
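For context, here is a minimal sketch of the pattern I mean; the module names, sizes, and the CUDA device in the usage comment are just assumptions for illustration:

```python
import torch
import torch.nn as nn

class MixedPrecisionBlock(nn.Module):
    # One overflow-prone submodule is kept in fp32 while the rest
    # of the model runs in fp16 (names and sizes are made up).
    def __init__(self):
        super().__init__()
        self.stable = nn.Linear(16, 16).half()      # stays in fp16
        self.sensitive = nn.Linear(16, 16).float()  # upcast to fp32 to avoid the overflow

    def forward(self, x):                  # x arrives as fp16
        x = self.stable(x)
        x = self.sensitive(x.float())      # cast activations up for the fp32 module
        return x.half()                    # cast back down for the fp16 layers that follow

# Usage (assumes a CUDA device, since fp16 is typically run on GPU):
# block = MixedPrecisionBlock().cuda()
# out = block(torch.randn(2, 16, dtype=torch.float16, device="cuda"))
```

Switching the `sensitive` module (and its casts) from fp32 to bf16 is what triggers the overflow in backward.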

The backward pass will match the dtypes used in the forward pass. I.e. if you are using bfloat16 in the forward, the backward will also use bfloat16 and not float16.
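You can verify this with a small standalone check:

```python
import torch

# Gradients are produced in the same dtype the forward computation used
x = torch.randn(4, 4, dtype=torch.bfloat16, requires_grad=True)
w = torch.randn(4, 4, dtype=torch.bfloat16, requires_grad=True)

(x @ w).sum().backward()
print(x.grad.dtype)  # torch.bfloat16
print(w.grad.dtype)  # torch.bfloat16
```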

Thanks for your answer, which cleared up my confusion.