I ran into a problem when using torch.amp: half-precision (float16) tensors give a different result from the default float32 tensors.
```python
import torch

a = torch.tensor([2111., 2010., 1970., 1781., 1925., 1993., 2122., 2078., 2094., 2093., 1887., 2103., 1807., 2101., 1766., 1688., 1564., 2119., 1830., 1230.])
b = torch.tensor([1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.])
print(a @ b)                # tensor(2111.)
print(a.half() @ b.half())  # tensor(2112., dtype=torch.float16)
```
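One way to see where the extra 1 comes from is to round-trip `a` through float16 and then do the matmul in float32. This is a minimal sketch using only stock PyTorch; `a_rounded` and `eps` are illustrative names. It shows that the `a.half()` cast by itself turns 2111 into 2112, before any matrix multiplication happens. float16 keeps 11 significant bits, so representable values in [2048, 4096) are spaced 2 apart, and 2111 sits exactly halfway between 2110 and 2112.

```python
import torch

a = torch.tensor([2111., 2010., 1970., 1781., 1925., 1993., 2122., 2078., 2094., 2093.,
                  1887., 2103., 1807., 2101., 1766., 1688., 1564., 2119., 1830., 1230.])
b = torch.zeros(20)
b[0] = 1.0  # the same one-hot vector as in the question

print((a @ b).item())          # 2111.0: everything stays in float32

# Cast a to float16 and back, then multiply in float32.
a_rounded = a.half().float()
print(a_rounded[0].item())     # 2112.0: the cast alone already changed 2111
print((a_rounded @ b).item())  # 2112.0: a float32 matmul on the rounded values

# float16 machine epsilon is 2**-10; near 2**11 the gap between neighbors is eps * 2**11.
eps = torch.finfo(torch.float16).eps
print(eps * 2 ** 11)           # 2.0
```

Because the float32 matmul on the rounded input gives the same 2112, the difference comes from float16's limited precision when the input is stored, not from the matmul itself.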
How can I avoid this problem?
I would appreciate any help!