When does nan get turned into inf?

I am running two Conv2d layers on a tensor of nans and getting -infs as output.

Two questions:

  1. why?
  2. is there a way to prevent this and keep them as nans?

Hi Sam!

This is strange, and not something I would expect to happen. After all,
“nan” is the “One Floating-Point Number to rule them all.”

“nan” should infect all arithmetic and turn the results into “nan”:

>>> import torch
>>> torch.__version__
'1.9.0'
>>> tnan = torch.tensor ([float ('nan')])
>>> tinf = torch.tensor ([float ('inf')])
>>> tnan
tensor([nan])
>>> tnan + tinf
tensor([nan])
>>> tnan * tinf
tensor([nan])
>>> tnan / 0.0
tensor([nan])

As expected, when I feed “nan” into a Conv2d, I get “nan” out:

>>> conv = torch.nn.Conv2d (1, 1, 3)
>>> t = torch.randn (1, 1, 10, 10)
>>> t[0, 0, 4, 4] = tnan
>>> conv (t)
tensor([[[[ 0.0393,  0.0779,  0.0731,  0.8591, -0.8162,  0.6578, -0.8222,
            0.8873],
          [-1.3210,  1.4260,  0.0526, -0.5711,  1.2245,  0.7201, -0.3848,
           -0.2762],
          [ 0.1857, -0.7266,     nan,     nan,     nan,  0.8714,  0.4864,
            0.2397],
          [-0.7764,  1.5293,     nan,     nan,     nan, -0.0663,  0.2233,
           -0.5896],
          [ 0.2750, -0.0982,     nan,     nan,     nan, -0.2574, -0.1529,
            0.5295],
          [ 0.7697,  1.0994, -0.3693, -0.5683, -0.4822,  0.9385, -0.7202,
            0.7361],
          [-1.3883,  0.6983,  0.3545,  0.8573, -0.0080, -0.2240,  0.2517,
           -0.4848],
          [-0.5190,  0.1873,  0.3782,  0.3108, -0.1297, -0.5012,  1.1124,
           -0.7626]]]], grad_fn=<ThnnConv2DBackward>)

Can you reproduce this with a short, self-contained, runnable script?

Best.

K. Frank

Are you also running a max-pooling layer? I can see how a max-pooling implementation somewhere might start with

max = -inf

and then try to find the max of a window, but every comparison against nan is false, so the max might simply remain -inf.
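
Purely as an illustration of that failure mode (a toy sketch, not PyTorch's actual kernel):

import math

def naive_window_max(window):
    # Hypothetical sketch: a max reduction seeded with -inf.
    # Every comparison against nan is False, so the running max is
    # never updated and -inf survives into the output.
    running_max = -math.inf
    for v in window:
        if v > running_max:   # nan > -inf evaluates to False
            running_max = v
    return running_max

print(naive_window_max([float('nan')] * 9))   # prints -inf, not nan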

Hi smth!

PyTorch’s min() and max() appear to “do the right thing” with nan:

>>> import torch
>>> torch.__version__
'1.9.0'
>>> tnan = torch.tensor ([float ('nan')])
>>> tinf = torch.tensor ([float ('inf')])
>>> torch.min (tinf, tnan)
tensor([nan])
>>> torch.max (tinf, tnan)
tensor([nan])

Also, MaxPool2d appears to work as “expected”:

>>> maxp = torch.nn.MaxPool2d (3)
>>> t = torch.randn (1, 1, 10, 10)
>>> t[0, 0, 4, 4] = tnan
>>> maxp (t)
tensor([[[[1.1035, 1.3305, 3.2121],
          [2.3521,    nan, 1.2039],
          [0.7372, 0.4652, 2.2204]]]])

Best.

K. Frank

Okay, I have a short reproduction script. It returns:

tensor(True, device='cuda:0')
tensor(False, device='cuda:0')

on CUDA, but not on CPU:

import torch
from torch import nn

class CNN(nn.Module):
    def __init__(self):
        super().__init__()
        self.CNN = nn.Sequential(nn.Conv2d(1, 32, 3, stride=2))

    def forward(self, x):
        return self.CNN(x)

cnn = CNN().to('cuda')
x = torch.full([128, 1, 28, 28], float('nan')).to('cuda')
print(torch.isnan(x).all())
y = cnn(x)
print(torch.isnan(y).all())

On CPU, it returns

tensor(True)
tensor(True)

as expected.

An even shorter version:

import torch

cnn = torch.nn.Conv2d(1, 32, 3, stride=2).to('cuda')
x = torch.full([128, 1, 28, 28], float('nan')).to('cuda')
print(torch.isnan(x).all())
y = cnn(x)
print(torch.isnan(y).all())

Prints:

tensor(True, device='cuda:0')
tensor(False, device='cuda:0')
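
For completeness, here is a side-by-side sketch (assuming a CUDA device is available) that runs the same check on both backends in one go:

import torch

# Compare the CPU and CUDA backends on the same all-nan input.
for device in ['cpu', 'cuda']:
    conv = torch.nn.Conv2d(1, 32, 3, stride=2).to(device)
    x = torch.full([128, 1, 28, 28], float('nan'), device=device)
    y = conv(x)
    print(device,
          '| input all nan:', torch.isnan(x).all().item(),
          '| output all nan:', torch.isnan(y).all().item())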

Hi Sam!

Okay, this is fun …

My take:

import torch
print (torch.__version__)
print (torch.version.cuda)
print (torch.cuda.get_device_name())

# (kernel, stride, H, W) combinations to test with an all-nan input on the GPU
configs = [
    (1, 1, 1, 1),
    (3, 1, 3, 74),
    (3, 1, 3, 75),
    (3, 1, 2048, 74),
    (3, 1, 2048, 75),
    (1, 2, 1, 1),
    (1, 2, 2, 1),
    (1, 2, 1, 2048),
    (1, 2, 2, 2048),
    (3, 2, 10, 3),
    (3, 2, 11, 3),
    (3, 2, 3, 14),
    (3, 2, 3, 15),
]

for kernel, stride, H, W in configs:
    conv = torch.nn.Conv2d (1, 1, kernel_size = kernel, stride = stride).cuda()
    res = conv (torch.full ((1, 1, H, W), float ('nan'), device = 'cuda'))
    print ('kernel:', kernel, ', stride:', stride, ', H:', H, ', W:', W)
    print ('all nans:', torch.isnan (res).all().item(), '  all infs:', torch.isinf (res).all().item())

And its output:

1.10.0
10.2
GeForce GTX 1050 Ti
kernel: 1 , stride: 1 , H: 1 , W: 1
all nans: False   all infs: True
kernel: 3 , stride: 1 , H: 3 , W: 74
all nans: True   all infs: False
kernel: 3 , stride: 1 , H: 3 , W: 75
all nans: False   all infs: True
kernel: 3 , stride: 1 , H: 2048 , W: 74
all nans: True   all infs: False
kernel: 3 , stride: 1 , H: 2048 , W: 75
all nans: False   all infs: True
kernel: 1 , stride: 2 , H: 1 , W: 1
all nans: True   all infs: False
kernel: 1 , stride: 2 , H: 2 , W: 1
all nans: False   all infs: True
kernel: 1 , stride: 2 , H: 1 , W: 2048
all nans: True   all infs: False
kernel: 1 , stride: 2 , H: 2 , W: 2048
all nans: False   all infs: True
kernel: 3 , stride: 2 , H: 10 , W: 3
all nans: True   all infs: False
kernel: 3 , stride: 2 , H: 11 , W: 3
all nans: False   all infs: True
kernel: 3 , stride: 2 , H: 3 , W: 14
all nans: True   all infs: False
kernel: 3 , stride: 2 , H: 3 , W: 15
all nans: False   all infs: True

Technically speaking, I would say that this is a bug. @ptrblck?

Best.

K. Frank

It could be a known issue with e.g. older cuDNN versions.
Since you are using the CUDA 10.2 binaries, your cuDNN version would be 7.6.5, so could you update to the latest CUDA 11 binaries? (A quick way to check the bundled cuDNN version is sketched after the output below.)
With CUDA 11.5 and cuDNN 8.3.2 I get:

1.11.0.dev20220108+cu115
11.5
NVIDIA GeForce RTX 3090
kernel: 1 , stride: 1 , H: 1 , W: 1
all nans: True   all infs: False
kernel: 3 , stride: 1 , H: 3 , W: 74
all nans: True   all infs: False
kernel: 3 , stride: 1 , H: 3 , W: 75
all nans: True   all infs: False
kernel: 3 , stride: 1 , H: 2048 , W: 74
all nans: True   all infs: False
kernel: 3 , stride: 1 , H: 2048 , W: 75
all nans: True   all infs: False
kernel: 1 , stride: 2 , H: 1 , W: 1
all nans: True   all infs: False
kernel: 1 , stride: 2 , H: 2 , W: 1
all nans: True   all infs: False
kernel: 1 , stride: 2 , H: 1 , W: 2048
all nans: True   all infs: False
kernel: 1 , stride: 2 , H: 2 , W: 2048
all nans: True   all infs: False
kernel: 3 , stride: 2 , H: 10 , W: 3
all nans: True   all infs: False
kernel: 3 , stride: 2 , H: 11 , W: 3
all nans: True   all infs: False
kernel: 3 , stride: 2 , H: 3 , W: 14
all nans: True   all infs: False
kernel: 3 , stride: 2 , H: 3 , W: 15
all nans: True   all infs: False
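
For reference, a small sketch to check which CUDA and cuDNN versions your PyTorch binaries ship with (cuDNN reports e.g. 7605 for 7.6.5):

import torch

# Report the PyTorch build and the CUDA / cuDNN versions it was built against.
print(torch.__version__)                  # PyTorch version
print(torch.version.cuda)                 # CUDA version of the binaries
print(torch.backends.cudnn.version())     # cuDNN version, e.g. 7605 for 7.6.5
print(torch.backends.cudnn.enabled)       # whether cuDNN is enabled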

Hi @ptrblck!

Yes, I can confirm that upgrading to CUDA 11 makes the infs go away.
After updating my nightly installation to “cudatoolkit=11.3,” I get:

1.11.0.dev20220124
11.3
GeForce GTX 1050 Ti
kernel: 1 , stride: 1 , H: 1 , W: 1
all nans: True   all infs: False
kernel: 3 , stride: 1 , H: 3 , W: 74
all nans: True   all infs: False
kernel: 3 , stride: 1 , H: 3 , W: 75
all nans: True   all infs: False
kernel: 3 , stride: 1 , H: 2048 , W: 74
all nans: True   all infs: False
kernel: 3 , stride: 1 , H: 2048 , W: 75
all nans: True   all infs: False
kernel: 1 , stride: 2 , H: 1 , W: 1
all nans: True   all infs: False
kernel: 1 , stride: 2 , H: 2 , W: 1
all nans: True   all infs: False
kernel: 1 , stride: 2 , H: 1 , W: 2048
all nans: True   all infs: False
kernel: 1 , stride: 2 , H: 2 , W: 2048
all nans: True   all infs: False
kernel: 3 , stride: 2 , H: 10 , W: 3
all nans: True   all infs: False
kernel: 3 , stride: 2 , H: 11 , W: 3
all nans: True   all infs: False
kernel: 3 , stride: 2 , H: 3 , W: 14
all nans: True   all infs: False
kernel: 3 , stride: 2 , H: 3 , W: 15
all nans: True   all infs: False

Best.

K. Frank


My school’s Nvidia driver might not support CUDA >10.2 :confused: