When PyTorch converts fp32 to bfloat16, does it truncate or round by default?

x = torch.randn([3, 4])   # fp32
x.to(torch.bfloat16)      # bfloat16

I see that it has utility functions for both, but how can I find out which one is triggered by default?
I think you can test this with a simple round-trip script. Truncation toward zero can only make positive values smaller, so if any element comes back larger after the conversion, PyTorch must be rounding:

import torch

n_elems = 10
a = torch.rand(n_elems, dtype=torch.float32)  # positive fp32 values in [0, 1)
b = a.bfloat16()                              # fp32 -> bfloat16
c = b.float()                                 # back to fp32 for comparison
print(f"got {(c > a).sum()} elements rounded up out of {n_elems}")
print(c, a)
Output:

got 6 elements rounded up out of 10
tensor([0.0253, 0.2080, 0.1826, 0.1118, 0.3809, 0.4434, 0.5742, 0.9453, 0.8789,
        0.4004]) tensor([0.0253, 0.2080, 0.1821, 0.1119, 0.3806, 0.4438, 0.5732, 0.9467, 0.8782,
        0.3997])

Some values went up and some went down, so the conversion rounds to nearest rather than truncating.
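A more direct check is to construct an fp32 value whose dropped mantissa bits sit just above the halfway point between two adjacent bfloat16 values: truncation would simply drop those bits, while round-to-nearest bumps the result up. Here is a minimal sketch; the specific bit patterns are my own choices for illustration:

import struct
import torch

def fp32_from_bits(b):
    """Reinterpret a 32-bit integer as an IEEE-754 float32."""
    return struct.unpack('>f', struct.pack('>I', b))[0]

# 1.0 is 0x3F800000 in fp32. bfloat16 keeps only the top 16 bits, so set
# the dropped low 16 bits just above the halfway point (0x8001 > 0x8000):
# truncation would give back 1.0, rounding gives the next bf16 step up.
x = torch.tensor(fp32_from_bits(0x3F808001), dtype=torch.float32)
y = x.to(torch.bfloat16)
print(y.item())  # 1.0078125 under rounding; would be 1.0 under truncation

On recent PyTorch builds this prints 1.0078125, i.e. the value was rounded up. As far as I know the conversion uses the standard round-to-nearest-even mode, so an exact tie (low bits exactly 0x8000) rounds toward the even mantissa rather than always up.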