Hello
I am currently experiencing an issue during training where NaN values appear in every forward pass.
Here is a minimal example that reproduces the error:
```python
import torch
from torch import nn


class Net(nn.Sequential):
    def __init__(self):
        super().__init__(
            nn.Conv1d(513, 512, kernel_size=3, bias=False),
            nn.SELU(),
            nn.Conv1d(512, 128, kernel_size=3, bias=False),
            nn.SELU(),
            nn.Conv1d(128, 128, kernel_size=3, bias=False),
            nn.SELU(),
            nn.Conv1d(128, 64, kernel_size=3, bias=False),
            nn.SELU(),
        )

    def forward(self, x):
        # Run the layers one by one and report the first one that produces a NaN.
        for i, layer in enumerate(self):
            new_x = layer(x)
            if new_x.isnan().any():
                print(f'first nan at stage {i}:')
                print(layer)
                print(f'input mean={x.mean():.5f}; std={x.std():.5f}; min={x.min():.5f}; max={x.max():.5f}')
                # SELU layers have no weight, so only print weight stats when one exists.
                w = getattr(layer, 'weight', None)
                if w is not None:
                    print(f'layer weight mean={w.mean():.5f}; std={w.std():.5f}; min={w.min():.5f}; max={w.max():.5f}')
                break
            x = new_x
        return new_x


net = Net()
net = net.cuda()
x = torch.randn((16, 513, 256), device='cuda')
with torch.cuda.amp.autocast(enabled=True):
    y = net(x)
assert not y.isnan().any()
print("OK")
```
Here is the result in my environment:
```
first nan at stage 0:
Conv1d(513, 512, kernel_size=(3,), stride=(1,), bias=False)
input mean=-0.00087; std=1.00048; min=-4.66932; max=5.45892
layer weight mean=0.00000; std=0.01473; min=-0.02549; max=0.02549
Traceback (most recent call last):
  File ".../playground.py", line 41, in <module>
    assert not y.isnan().any()
AssertionError
```
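The statistics printed for stage 0 suggest the NaN is not a genuine float16 overflow. As a rough sanity check, the sketch below bounds the magnitude of any output (and any partial sum) of the first convolution using only the printed values: with no bias, each output element sums 513 × 3 = 1539 products. It compares that bound with the largest finite float16 value, 65504, using the standard library's `struct` `'e'` (IEEE 754 half-precision) format. Treating the printed min/max as exact and assuming independent inputs and weights for the "typical" estimate are assumptions.

```python
import math
import struct

# Values copied from the stage-0 diagnostic output.
IN_CHANNELS, KERNEL_SIZE = 513, 3
MAX_ABS_INPUT = 5.45892   # max(|min|, |max|) of the input
MAX_ABS_WEIGHT = 0.02549  # max(|min|, |max|) of the weight
INPUT_STD, WEIGHT_STD = 1.00048, 0.01473

# Largest finite IEEE 754 half-precision value: (2 - 2**-10) * 2**15.
FP16_MAX = (2 - 2 ** -10) * 2 ** 15


def fits_in_fp16(value: float) -> bool:
    """Return True if `value` can be packed as a finite float16."""
    try:
        struct.pack('<e', value)
    except OverflowError:
        return False
    return True


terms = IN_CHANNELS * KERNEL_SIZE  # products summed per output element
# Upper bound on |output| and on every partial sum along the way.
worst_case = terms * MAX_ABS_INPUT * MAX_ABS_WEIGHT
# Rough typical magnitude, assuming independent zero-mean inputs and weights.
typical_std = math.sqrt(terms) * INPUT_STD * WEIGHT_STD

print(f'FP16 max: {FP16_MAX}')
print(f'worst-case |output|: {worst_case:.1f}; typical std: {typical_std:.3f}')
print('worst case fits in fp16:', fits_in_fp16(worst_case))
```

With these numbers the worst case is about 214, far below 65504, which suggests (without proving it) that the NaNs come from the convolution kernel or the software stack rather than from the data or the weights.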
I presume the cause is a hardware/environment issue, because everything runs fine on Google Colab.
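One way to narrow down a suspected environment issue is to run only the first convolution under different settings: float32 vs. float16, and with cuDNN enabled vs. disabled through `torch.backends.cudnn.enabled`. In the sketch below, `first_conv_has_nan` is a hypothetical helper (not part of PyTorch), and casting the layer and input with an explicit `dtype` instead of using autocast is an assumption about what autocast does for `Conv1d` (it runs convolutions in float16). If only the float16 + cuDNN combination produces NaNs, that would point at the cuDNN/driver stack of this setup rather than at the model. Without CUDA it only runs the float32 CPU case.

```python
import torch
from torch import nn


def first_conv_has_nan(device: str, dtype: torch.dtype, cudnn_enabled: bool = True) -> bool:
    """Hypothetical helper: run the model's first Conv1d once and report NaNs."""
    torch.manual_seed(0)
    conv = nn.Conv1d(513, 512, kernel_size=3, bias=False).to(device=device, dtype=dtype)
    x = torch.randn((16, 513, 256), device=device, dtype=dtype)
    previous = torch.backends.cudnn.enabled
    torch.backends.cudnn.enabled = cudnn_enabled
    try:
        with torch.no_grad():
            y = conv(x)
    finally:
        # Restore the global cuDNN setting whatever happens.
        torch.backends.cudnn.enabled = previous
    return bool(y.isnan().any())


if torch.cuda.is_available():
    for dtype in (torch.float32, torch.float16):
        for cudnn_enabled in (True, False):
            has_nan = first_conv_has_nan('cuda', dtype, cudnn_enabled)
            print(f'dtype={dtype}, cudnn={cudnn_enabled}: NaN={has_nan}')
else:
    # Half-precision convolutions may not be supported on CPU, so only float32 is checked.
    print('CPU float32 NaN:', first_conv_has_nan('cpu', torch.float32))
```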
I may have misunderstood how AMP is supposed to be used, but I ran into this problem in a PyTorch Lightning project, where I don't call AMP myself.
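For reference, the pattern documented for native AMP in PyTorch wraps only the forward pass and the loss in `torch.cuda.amp.autocast` and scales the loss with `torch.cuda.amp.GradScaler` before `backward()`. Assuming the Lightning project enables mixed precision with `Trainer(precision=16)`, Lightning's native AMP mode is expected to do the equivalent internally. The toy model, optimizer, and data below are hypothetical and only show where each call goes; on a machine without CUDA, AMP is simply disabled.

```python
import torch
from torch import nn

use_amp = torch.cuda.is_available()  # torch.cuda.amp only takes effect on CUDA
device = 'cuda' if use_amp else 'cpu'

# Hypothetical toy model, optimizer and data.
model = nn.Sequential(nn.Conv1d(513, 64, kernel_size=3, bias=False), nn.SELU()).to(device)
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)
scaler = torch.cuda.amp.GradScaler(enabled=use_amp)

x = torch.randn((4, 513, 32), device=device)
target = torch.zeros((4, 64, 30), device=device)  # 32 - 3 + 1 = 30 output steps

optimizer.zero_grad()
with torch.cuda.amp.autocast(enabled=use_amp):
    # Only the forward pass and the loss run under autocast.
    out = model(x)
    loss = nn.functional.mse_loss(out, target)
# backward() and the optimizer step run outside autocast.
scaler.scale(loss).backward()
scaler.step(optimizer)  # skips the update if the unscaled gradients contain inf/NaN
scaler.update()
print(f'loss={loss.item():.4f}')
```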
Here is my environment:
I run the code in a virtualenv under WSL 2, launched from PyCharm.
My GPU is a GTX 1660 Ti.
```
$ lsb_release -a && uname -a
No LSB modules are available.
Distributor ID: Ubuntu
Description: Ubuntu 20.04.2 LTS
Release: 20.04
Codename: focal
Linux PC-portable-Augustin 5.4.72-microsoft-standard-WSL2 #1 SMP Wed Oct 28 23:40:43 UTC 2020 x86_64 x86_64 x86_64 GNU/Linux
$ ./virtualenvs/torch/bin/pip freeze
absl-py==0.12.0
aiohttp==3.7.4.post0
appdirs==1.4.3
async-timeout==3.0.1
attrs==20.3.0
CacheControl==0.12.6
cachetools==4.2.1
certifi==2019.11.28
chardet==3.0.4
colorama==0.4.3
contextlib2==0.6.0
cycler==0.10.0
distlib==0.3.0
distro==1.4.0
fsspec==0.8.7
future==0.18.2
google-auth==1.28.0
google-auth-oauthlib==0.4.4
grpcio==1.36.1
html5lib==1.0.1
idna==2.8
ipaddr==2.2.0
kiwisolver==1.3.1
lockfile==0.12.2
Markdown==3.3.4
matplotlib==3.4.1
msgpack==0.6.2
multidict==5.1.0
numpy==1.20.2
oauthlib==3.1.0
packaging==20.3
pep517==0.8.2
Pillow==8.2.0
progress==1.5
protobuf==3.15.7
pyasn1==0.4.8
pyasn1-modules==0.2.8
pyparsing==2.4.6
python-dateutil==2.8.1
pytoml==0.1.21
pytorch-lightning==1.2.6
PyYAML==5.3.1
requests==2.22.0
requests-oauthlib==1.3.0
retrying==1.3.3
rsa==4.7.2
six==1.14.0
tensorboard==2.4.1
tensorboard-plugin-wit==1.8.0
torch==1.8.1+cu111
torchaudio==0.8.1
torchmetrics==0.2.0
torchvision==0.9.1+cu111
tqdm==4.59.0
typing-extensions==3.7.4.3
urllib3==1.25.8
webencodings==0.5.1
Werkzeug==1.0.1
yarl==1.6.3
```
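`pip freeze` does not show which CUDA runtime and cuDNN versions the installed torch wheel actually uses, which matter for a GPU-specific NaN like this one. A small sketch that prints them is below; `environment_info` is a hypothetical helper name. PyTorch also ships its own report script, `python -m torch.utils.collect_env`, which additionally lists the NVIDIA driver version.

```python
import torch


def environment_info() -> dict:
    """Collect the versions relevant to CUDA/cuDNN issues (pip freeze omits them)."""
    info = {
        'torch': torch.__version__,
        'cuda_runtime': torch.version.cuda,       # None for CPU-only builds
        'cudnn': torch.backends.cudnn.version(),  # None if cuDNN is unavailable
        'cuda_available': torch.cuda.is_available(),
    }
    if info['cuda_available']:
        info['device'] = torch.cuda.get_device_name(0)
        info['capability'] = torch.cuda.get_device_capability(0)
    return info


for key, value in environment_info().items():
    print(f'{key}: {value}')
```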
Does anyone have an idea what is wrong with my system?
Thank you