I trained the model in mixed precision. In a certain function, an input tensor in `float16` is transformed into `float32` after `torch.norm`: `dis_vec` is `float16`, but `dis` is `float32`. And if I use `dis = torch.sqrt((dis_vec ** 2).sum(-1))` instead, `dis` is also `float32`.
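A minimal sketch of what I'm seeing (shapes and names are made up; the autocast part assumes a CUDA device is available):

```python
import torch

# Hypothetical half-precision distance vectors.
dis_vec = torch.randn(8, 3, dtype=torch.float16)

if torch.cuda.is_available():
    dis_vec = dis_vec.cuda()
    with torch.autocast(device_type="cuda"):
        dis = torch.norm(dis_vec, dim=-1)          # comes out as float32
        dis2 = torch.sqrt((dis_vec ** 2).sum(-1))  # also float32
    print(dis_vec.dtype, dis.dtype, dis2.dtype)
```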
`autocast` uses an internal "allow-list" to cast tensors into `float16` if the operation is considered safe at this precision. The `autocast` docs give you some more information.
Are you seeing any `dtype` mismatch errors in your code (inside the `autocast` region), as this could be a bug?
No mismatch errors. In principle, the op `torch.norm` should not change the dtype, so why does the output tensor become `float32` even though my input tensor is `float16`?
`torch.norm` is an op which autocasts to `float32`, as given in this list.
Thanks! So if I want to continue the following computation in `float16`, do I need to set it to `float16` manually via `dis.half()`?
If you want to apply the norm using `float16` values and are sure that you won't run into numerical issues, you could disable `autocast` for this operation and manually cast the tensor to the desired type (you can use nested `autocast` decorators).