Hi, I’m seeing a numerical discrepancy between F.conv2d computed on CPU and on GPU. This behaviour is only observed on A100 GPUs and with recent versions of PyTorch.
test.py:
import numpy as np
import torch
from torch.nn import functional as F
print("TORCH VERSION: {}".format(torch.version.__version__))
print("CUDA VERSION: {}".format(torch.version.cuda))
print("CUDNN VERSION: {}".format(torch.backends.cudnn.version()))
print('')
x = torch.from_numpy(np.load('x.npz')['data'])
W = torch.from_numpy(np.load('W.npz')['data'])
y_cpu = F.conv2d(x, weight=W, padding=1)
print(f'on CPU: {y_cpu.sum().item():.4f}')
y_cuda = F.conv2d(x.to('cuda'), weight=W.to('cuda'), padding=1)
print(f'on GPU: {y_cuda.sum().item():.4f}')
You can find the data files (W.npz, x.npz) here and here if you want to reproduce the result yourself.
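Since the sum over the whole output can hide or exaggerate per-element differences, here is a small variant of the comparison I can also run on the same data files; it is just a sketch extending test.py above, using only standard torch comparison utilities:
import numpy as np
import torch
from torch.nn import functional as F

# same input and weight tensors as in test.py
x = torch.from_numpy(np.load('x.npz')['data'])
W = torch.from_numpy(np.load('W.npz')['data'])

y_cpu = F.conv2d(x, weight=W, padding=1)
y_cuda = F.conv2d(x.to('cuda'), weight=W.to('cuda'), padding=1).cpu()

# element-wise discrepancy, which says more than the difference of the sums
abs_diff = (y_cpu - y_cuda).abs()
rel_diff = abs_diff / y_cpu.abs().clamp_min(1e-12)
print(f'max abs diff: {abs_diff.max().item():.6e}')
print(f'max rel diff: {rel_diff.max().item():.6e}')
print('allclose with default tolerances:', torch.allclose(y_cpu, y_cuda))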
First, here’s the result for PyTorch 1.10.1.
torch.utils.collect_env:
Collecting environment information...
PyTorch version: 1.10.1+cu111
Is debug build: False
CUDA used to build PyTorch: 11.1
ROCM used to build PyTorch: N/A
OS: Ubuntu 18.04.6 LTS (x86_64)
GCC version: (Ubuntu 7.5.0-3ubuntu1~18.04) 7.5.0
Clang version: Could not collect
CMake version: Could not collect
Libc version: glibc-2.27
Python version: 3.8.0 (default, Dec 9 2021, 17:53:27) [GCC 8.4.0] (64-bit runtime)
Python platform: Linux-5.4.0-1030-aws-x86_64-with-glibc2.27
Is CUDA available: True
CUDA runtime version: 11.1.105
GPU models and configuration:
GPU 0: A100-SXM4-40GB
GPU 1: A100-SXM4-40GB
GPU 2: A100-SXM4-40GB
GPU 3: A100-SXM4-40GB
GPU 4: A100-SXM4-40GB
GPU 5: A100-SXM4-40GB
GPU 6: A100-SXM4-40GB
GPU 7: A100-SXM4-40GB
Nvidia driver version: 450.80.02
cuDNN version: Probably one of the following:
/usr/lib/x86_64-linux-gnu/libcudnn.so.8.0.5
/usr/lib/x86_64-linux-gnu/libcudnn_adv_infer.so.8.0.5
/usr/lib/x86_64-linux-gnu/libcudnn_adv_train.so.8.0.5
/usr/lib/x86_64-linux-gnu/libcudnn_cnn_infer.so.8.0.5
/usr/lib/x86_64-linux-gnu/libcudnn_cnn_train.so.8.0.5
/usr/lib/x86_64-linux-gnu/libcudnn_ops_infer.so.8.0.5
/usr/lib/x86_64-linux-gnu/libcudnn_ops_train.so.8.0.5
HIP runtime version: N/A
MIOpen runtime version: N/A
Versions of relevant libraries:
[pip3] numpy==1.22.3
[pip3] torch==1.10.1+cu111
[conda] Could not collect
python test.py:
TORCH VERSION: 1.10.1+cu111
CUDA VERSION: 11.1
CUDNN VERSION: 8005
on CPU: -62344.9922
on GPU: -62346.3828
Note the numerical difference between the F.conv2d results on CPU and GPU. Exactly the same discrepancy occurs when I upgrade PyTorch to v1.13 or downgrade it to v1.9. When I downgrade to v1.8, the discrepancy is gone:
torch.utils.collect_env:
Collecting environment information...
PyTorch version: 1.8.1+cu111
Is debug build: False
CUDA used to build PyTorch: 11.1
ROCM used to build PyTorch: N/A
OS: Ubuntu 18.04.6 LTS (x86_64)
GCC version: (Ubuntu 7.5.0-3ubuntu1~18.04) 7.5.0
Clang version: Could not collect
CMake version: Could not collect
Python version: 3.8 (64-bit runtime)
Is CUDA available: True
CUDA runtime version: Could not collect
GPU models and configuration:
GPU 0: A100-SXM4-40GB
GPU 1: A100-SXM4-40GB
GPU 2: A100-SXM4-40GB
GPU 3: A100-SXM4-40GB
GPU 4: A100-SXM4-40GB
GPU 5: A100-SXM4-40GB
GPU 6: A100-SXM4-40GB
GPU 7: A100-SXM4-40GB
Nvidia driver version: 450.80.02
cuDNN version: Probably one of the following:
/usr/lib/x86_64-linux-gnu/libcudnn.so.8.0.5
/usr/lib/x86_64-linux-gnu/libcudnn_adv_infer.so.8.0.5
/usr/lib/x86_64-linux-gnu/libcudnn_adv_train.so.8.0.5
/usr/lib/x86_64-linux-gnu/libcudnn_cnn_infer.so.8.0.5
/usr/lib/x86_64-linux-gnu/libcudnn_cnn_train.so.8.0.5
/usr/lib/x86_64-linux-gnu/libcudnn_ops_infer.so.8.0.5
/usr/lib/x86_64-linux-gnu/libcudnn_ops_train.so.8.0.5
HIP runtime version: N/A
MIOpen runtime version: N/A
Versions of relevant libraries:
[pip3] numpy==1.22.3
[pip3] torch==1.8.1+cu111
[conda] Could not collect
python test.py:
TORCH VERSION: 1.8.1+cu111
CUDA VERSION: 11.1
CUDNN VERSION: 8005
on CPU: -62344.9922
on GPU: -62344.9922
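Since the newer builds show the discrepancy while 1.8.1 does not, one extra check I plan to run on the newer builds (I haven’t verified yet that it changes anything) is to disable cuDNN for this op and see whether the GPU result then matches the CPU; that would at least tell me whether the difference comes from the cuDNN convolution path. A rough sketch:
import numpy as np
import torch
from torch.nn import functional as F

x = torch.from_numpy(np.load('x.npz')['data'])
W = torch.from_numpy(np.load('W.npz')['data'])
y_cpu = F.conv2d(x, weight=W, padding=1)

# fall back to the non-cuDNN GPU convolution implementation
torch.backends.cudnn.enabled = False
y_cuda = F.conv2d(x.to('cuda'), weight=W.to('cuda'), padding=1)

print(f'on CPU:             {y_cpu.sum().item():.4f}')
print(f'on GPU (cuDNN off): {y_cuda.sum().item():.4f}')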
Notes:
- This seems to happen specifically on A100. My colleagues ran the same test on different GPUs and told me there is no discrepancy there (I don’t have their output).
- While it may seem small, the impact of this error is very significant in deep conv nets, especially when the prediction requires high accuracy. For example, when I ran inference with the same model, I got 39% mAP (nuScenes 3D detection, higher is better) in the setting where test.py shows no error, but it dropped to 35% when I ran the same model in the setting where test.py shows this error.
- I could just stick to PyTorch 1.8.1, but for some reason training is much slower (almost 50%) than with the newer versions. I wish I could train faster with the newer versions but without this error (one check I’m planning is sketched below).
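The check mentioned in the last note: A100 supports TF32, a reduced-precision math mode that cuDNN convolutions can use when torch.backends.cudnn.allow_tf32 is True, so I want to see whether turning it off on the newer builds removes the discrepancy. I’m not sure it explains the difference from 1.8.1, since that flag exists there as well; this is just a sketch of the experiment:
import numpy as np
import torch
from torch.nn import functional as F

x = torch.from_numpy(np.load('x.npz')['data'])
W = torch.from_numpy(np.load('W.npz')['data'])
y_cpu = F.conv2d(x, weight=W, padding=1)

print('default cudnn.allow_tf32:', torch.backends.cudnn.allow_tf32)

# disable TF32 for cuDNN convolutions (and matmuls, for completeness)
torch.backends.cudnn.allow_tf32 = False
torch.backends.cuda.matmul.allow_tf32 = False

y_cuda = F.conv2d(x.to('cuda'), weight=W.to('cuda'), padding=1)
print(f'on CPU:            {y_cpu.sum().item():.4f}')
print(f'on GPU (TF32 off): {y_cuda.sum().item():.4f}')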
Does anybody have a clue about this?