Float32 to Bfloat16 conversion

When PyTorch converts fp32 to bfloat16, does it truncate or round by default?

x = torch.randn([3, 4])   # fp32
x = x.to(torch.bfloat16)  # bfloat16 (`.to()` returns a new tensor)

I see that it has utility functions to do both, but how can I tell which one is triggered by default?

I think you can write a simple script to test this. Truncating mantissa bits can only move a positive value toward zero, so if any element comes back larger than the original, the conversion must be rounding:

import torch

elems = 10
a = torch.rand(elems, dtype=torch.float32)
b = a.bfloat16()  # the fp32 -> bf16 conversion under test
c = b.float()     # widen back to fp32 (exact) so we can compare
print(f"got {(c > a).sum()} elements rounded up out of {elems}")
print(c, a)
got 6 elements rounded up out of 10
tensor([0.0253, 0.2080, 0.1826, 0.1118, 0.3809, 0.4434, 0.5742, 0.9453, 0.8789,
        0.4004]) tensor([0.0253, 0.2080, 0.1821, 0.1119, 0.3806, 0.4438, 0.5732, 0.9467, 0.8782,
        0.3997])
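
Roughly half the elements went up and half went down, which is what nearest rounding looks like. If you want to pin down the exact mode, here is a minimal sketch that reimplements round-to-nearest-even with integer bit tricks and checks it against `.to(torch.bfloat16)`; `rne_bf16` is just a name I made up, and it assumes finite, normal fp32 inputs (no NaN/inf handling):

import torch

def rne_bf16(x: torch.Tensor) -> torch.Tensor:
    # Reference fp32 -> bf16 rounding via round-to-nearest-even,
    # returned as fp32 with the dropped 16 mantissa bits zeroed.
    bits = x.view(torch.int32)  # reinterpret the raw fp32 bit pattern
    # classic RNE trick: add 0x7FFF plus the lowest surviving mantissa
    # bit, then clear the 16 bits that bfloat16 discards
    rounded = bits + 0x7FFF + ((bits >> 16) & 1)
    return (rounded & ~0xFFFF).view(torch.float32)

x = torch.randn(10_000)
# bf16 -> fp32 widening is exact, so equality here means the
# conversion rounded to nearest-even on every element
assert torch.equal(x.to(torch.bfloat16).float(), rne_bf16(x))
print("conversion matches round-to-nearest-even")

If this assertion passes on your build, the default is round-to-nearest-even; pure truncation (i.e. `bits & ~0xFFFF` with no rounding bias) would fail it on roughly half the elements.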