NaNs in torch.nn.functional.grid_sample with Mixed Precision

float16 can easily overflow if your values are close to its maximum (or minimum) representable value:

torch.finfo(torch.float16).max
> 65504.0

E.g. the second approach in this code snippet overflows and yields Infs in the result after applying torch.mm:

x = torch.randn(1024, 1024, device='cuda').half()
y = torch.mm(x, x)
print(torch.isfinite(y).all())
> tensor(True, device='cuda:0')

x = torch.randn(1024, 1024, device='cuda').half() * 2**13
print(torch.isfinite(x).all())
> tensor(True, device='cuda:0')

y = torch.mm(x, x)
print(torch.isfinite(y).all())
> tensor(False, device='cuda:0')

Since the grid passed to grid_sample is expected to contain normalized coordinates in [-1, 1], this would usually not be a problem.
However, it seems you are expecting to work with large values (which would then be mapped to the padding values), so you might disable autocast for the grid creation and the grid sampling operation.
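A minimal sketch of this approach (the shapes and grid values are made up for illustration; the point is the nested autocast(enabled=False) region and the cast back to float32):

```python
import torch
import torch.nn.functional as F

device = 'cuda' if torch.cuda.is_available() else 'cpu'
x = torch.randn(1, 3, 8, 8, device=device)

with torch.autocast(device_type=device):
    # ... mixed-precision forward pass producing a lower-precision tensor ...
    x_lp = x.half() if device == 'cuda' else x.bfloat16()

    # Disable autocast for the grid creation and the sampling itself,
    # and cast the input back to float32 so large grid values don't overflow.
    with torch.autocast(device_type=device, enabled=False):
        # hypothetical grid: values outside [-1, 1] hit the padding value
        grid = torch.empty(1, 4, 4, 2, device=device).uniform_(-10, 10)
        out = F.grid_sample(x_lp.float(), grid,
                            padding_mode='zeros', align_corners=False)

print(out.dtype)                   # torch.float32
print(torch.isfinite(out).all())
```

Since grid_sample only needs the coordinates and the feature map in float32 for this one call, the rest of the forward pass can stay under autocast and keep the mixed-precision speedup.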
