float16 can easily overflow if you are using values close to its min. and max. limits:

torch.finfo(torch.float16).max
> 65504.0

E.g. this code snippet overflows in the second approach and yields Infs in the result after applying torch.mm:
# First approach: unscaled inputs stay finite after the matmul
x = torch.randn(1024, 1024, device='cuda').half()
y = torch.mm(x, x)
print(torch.isfinite(y).all())
> tensor(True, device='cuda:0')

# Second approach: scaled inputs are still finite on their own,
# but the matmul accumulates values beyond the float16 range
x = torch.randn(1024, 1024, device='cuda').half() * 2**13
print(torch.isfinite(x).all())
> tensor(True, device='cuda:0')
y = torch.mm(x, x)
print(torch.isfinite(y).all())
> tensor(False, device='cuda:0')
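Since the snippet above needs a GPU, here is a minimal CPU-only sketch of the same failure mode (the values are made up for illustration): a single product of moderately large float16 values already exceeds 65504 and turns into inf, while float32 handles it easily.

```python
import torch

# float16 overflows past ~65504
print(torch.finfo(torch.float16).max)  # 65504.0

a = torch.tensor([300.0], dtype=torch.float16)
b = a * a  # 300 * 300 = 90000 > 65504 -> inf in float16
print(torch.isinf(b).all())  # tensor(True)

# the same computation stays finite in float32
print(torch.isfinite(a.float() * a.float()).all())  # tensor(True)
```

In the matmul case the effect is even stronger: each output element sums 1024 products, so with inputs scaled by 2**13 a single product is already around 2**26, far beyond the float16 range.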
Since the grid would contain valid values in [-1, 1], this would usually not be a problem. However, it seems you are expecting to work with large values (which would then use the padding values), so you might disable autocast for the grid sampling operation and the grid creation.
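A minimal sketch of disabling autocast locally for the grid creation and sampling (shown on the CPU with bfloat16 autocast so it runs anywhere; the tensor names and shapes are made up for illustration):

```python
import torch
import torch.nn.functional as F

feat = torch.randn(1, 3, 8, 8)  # hypothetical input feature map

with torch.autocast(device_type='cpu', dtype=torch.bfloat16):
    # ... other mixed-precision work ...

    # disable autocast just for the grid creation and grid_sample,
    # so these ops run in full float32 precision
    with torch.autocast(device_type='cpu', enabled=False):
        theta = torch.eye(2, 3).unsqueeze(0)  # identity transform as a stand-in
        grid = F.affine_grid(theta, feat.size(), align_corners=False)
        out = F.grid_sample(feat.float(), grid, align_corners=False)

print(out.dtype)  # torch.float32
```

On a GPU you would use `device_type='cuda'` instead; the key point is the nested `torch.autocast(..., enabled=False)` context combined with explicitly casting the inputs to float32.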