`float16` has a lower numerical range than `float32` and can thus easily over-/underflow.
Have a look at [Half-precision floating-point format - Wikipedia](https://en.wikipedia.org/wiki/Half-precision_floating-point_format) to check the specifics, in particular:

> They can express values in the range ±65,504, with the minimum value above 1 being 1 + 1/1024.
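As a quick check (my own snippet, not from the original post), you can verify these limits directly in PyTorch:

```python
import torch

# float16 can represent magnitudes only up to ~65504;
# anything larger saturates to inf.
print(torch.finfo(torch.float16).max)                  # 65504.0
print(torch.tensor(65504.0, dtype=torch.float16) * 2)  # tensor(inf, dtype=torch.float16)
```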
This range can easily overflow, as you can see if you manually apply the `cosine_similarity` computation taken from here:
```python
import torch

def manual(x1_, x2_, dim, eps):
    # manual approach: dot product and squared norms along `dim`
    w12 = torch.sum(x1_ * x2_, dim)
    print(w12)
    w1 = torch.sum(x1_ * x1_, dim)
    print(w1)
    w2 = torch.sum(x2_ * x2_, dim)
    print(w2)
    # product of the squared norms - this is the intermediate that overflows in float16
    out = (w1 * w2)
    print(out)
    out = out.clamp_min_(eps * eps)
    print(out)
    out = out.sqrt()
    print(out)
    n12 = out
    w12.div_(n12)
    print(w12.mean())

x = torch.randn(10, 2048).cuda()
y = torch.randn(10, 2048).cuda()
dim = 1
x1_ = x
x2_ = y
eps = 1e-8

manual(x1_, x2_, dim, eps)
print(torch.nn.CosineSimilarity()(x, y).mean())
manual(x1_.half(), x2_.half(), dim, eps)
```
You will see that the first call to `manual` using `float32` and the `nn.CosineSimilarity` call give the same results, but the manual computation also uses intermediate values outside of the valid `float16` range:
```
tensor([ 4.2845, -32.5558, 4.3298, 7.5555, 24.1382, 65.0509, -11.0637,
        80.8742, -4.0695, -48.8819], device='cuda:0')
tensor([1968.8096, 2079.9971, 1968.8093, 2038.1968, 2118.2397, 2037.6740,
        2009.8625, 2059.7161, 1933.5315, 2092.7104], device='cuda:0')
tensor([2048.4983, 2051.1265, 2075.0337, 1997.3761, 2019.7664, 1949.5663,
        1937.5840, 2021.9969, 2035.3262, 2116.4380], device='cuda:0')
tensor([4033103.0000, 4266337.0000, 4085345.7500, 4071045.5000, 4278349.5000,
        3972580.5000, 3894277.5000, 4164739.5000, 3935367.2500, 4429092.0000],
       device='cuda:0')
tensor([4033103.0000, 4266337.0000, 4085345.7500, 4071045.5000, 4278349.5000,
        3972580.5000, 3894277.5000, 4164739.5000, 3935367.2500, 4429092.0000],
       device='cuda:0')
tensor([2008.2587, 2065.5112, 2021.2239, 2017.6832, 2068.4172, 1993.1333,
        1973.3923, 2040.7693, 1983.7760, 2104.5408], device='cuda:0')
tensor(0.0045, device='cuda:0')
tensor(0.0045, device='cuda:0')
```
In `float16` these intermediate products (~4e6, far above the maximum representable value of 65,504) overflow to `inf`, so you would get invalid results:
```
tensor([ 4.2695, -32.5625, 4.3398, 7.5586, 24.1406, 65.0625, -11.0469,
        80.8750, -4.0586, -48.9062], device='cuda:0', dtype=torch.float16)
tensor([1969., 2080., 1969., 2038., 2118., 2038., 2010., 2060., 1934., 2092.],
       device='cuda:0', dtype=torch.float16)
tensor([2048., 2052., 2076., 1997., 2020., 1950., 1938., 2022., 2035., 2116.],
       device='cuda:0', dtype=torch.float16)
tensor([inf, inf, inf, inf, inf, inf, inf, inf, inf, inf], device='cuda:0',
       dtype=torch.float16)
tensor([inf, inf, inf, inf, inf, inf, inf, inf, inf, inf], device='cuda:0',
       dtype=torch.float16)
tensor([inf, inf, inf, inf, inf, inf, inf, inf, inf, inf], device='cuda:0',
       dtype=torch.float16)
tensor(0., device='cuda:0', dtype=torch.float16)
```
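One possible workaround (a sketch I'm adding here, not part of the original code) is to keep the data in `float16` but run the overflow-prone reductions in `float32` and cast the result back:

```python
import torch

def manual_fp32(x1_, x2_, dim, eps):
    # Upcast before the reductions so the intermediate sums and products
    # stay within the float32 range, then cast the result back.
    x1f, x2f = x1_.float(), x2_.float()
    w12 = torch.sum(x1f * x2f, dim)
    w1 = torch.sum(x1f * x1f, dim)
    w2 = torch.sum(x2f * x2f, dim)
    denom = (w1 * w2).clamp_min(eps * eps).sqrt()
    return (w12 / denom).to(x1_.dtype)

x = torch.randn(10, 2048).cuda().half()
y = torch.randn(10, 2048).cuda().half()
print(manual_fp32(x, y, dim=1, eps=1e-8).mean())
```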
No, I disagree, as `torch.cuda.amp.autocast` is used exactly for this reason: operations prone to overflow or general numerical instability are kept in `float32`, while other operations are allowed to use `float16` inputs and outputs even though the computation itself is often still done in `float32`.
Since you are manually calling `half()` on your data without relying on `autocast`, you would need to check the used operations and see if you might be running into numerical issues.
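For illustration (my sketch, not code from this thread), wrapping the computation in `autocast` lets PyTorch pick the dtype per operation; numerically sensitive ops such as the cosine similarity should then be executed in `float32`, which you can verify via the output dtype:

```python
import torch

x = torch.randn(10, 2048, device='cuda')
y = torch.randn(10, 2048, device='cuda')

# autocast selects the dtype per op: matmuls etc. may run in float16,
# while numerically sensitive ops are executed in float32.
with torch.cuda.amp.autocast():
    sim = torch.nn.functional.cosine_similarity(x, y, dim=1)
    print(sim.dtype)  # check which precision autocast picked for this op
```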
However, I agree that training models directly in `float16` is definitely not straightforward and might not work.