Numerical error on A100 GPUs

Hi, I’m seeing a numerical discrepancy between F.conv2d results computed on the CPU and on the GPU. This behaviour only shows up on A100 GPUs and with recent versions of PyTorch.

test.py:

import numpy as np
import torch
from torch.nn import functional as F

print("TORCH VERSION: {}".format(torch.version.__version__))
print("CUDA VERSION: {}".format(torch.version.cuda))
print("CUDNN VERSION: {}".format(torch.backends.cudnn.version()))
print('')

x = torch.from_numpy(np.load('x.npz')['data'])
W = torch.from_numpy(np.load('W.npz')['data'])

y_cpu = F.conv2d(x, weight=W, padding=1)
print(f'on CPU: {y_cpu.sum().item():.4f}')

y_cuda = F.conv2d(x.to('cuda'), weight=W.to('cuda'), padding=1)
print(f'on GPU: {y_cuda.sum().item():.4f}')

You can find the data files (W.npz, x.npz) here and here if you want to reproduce the result yourself.
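The script above compares only the sums of the two outputs, which can mask or exaggerate where the discrepancy lies. A minimal elementwise comparison in NumPy might look like the following. `compare_outputs` and its default tolerances are illustrative assumptions, not part of the original script. With the tensors from test.py you would pass `y_cpu.numpy()` and `y_cuda.cpu().numpy()`.

```python
import numpy as np


def compare_outputs(ref, test, rtol=1e-5, atol=1e-6):
    """Hypothetical helper: elementwise comparison of two float arrays.

    `ref` is treated as the reference (e.g. the CPU result) and `test` as
    the result under scrutiny (e.g. the GPU result copied back to the host).
    """
    ref64 = np.asarray(ref, dtype=np.float64)
    test64 = np.asarray(test, dtype=np.float64)
    if ref64.shape != test64.shape:
        raise ValueError(f"shape mismatch: {ref64.shape} vs {test64.shape}")

    abs_diff = np.abs(test64 - ref64)
    max_abs = float(abs_diff.max())
    # Normalize by the largest reference magnitude so near-zero outputs
    # do not blow up the relative figure.
    ref_scale = max(float(np.abs(ref64).max()), np.finfo(np.float64).tiny)
    return {
        "max_abs_diff": max_abs,
        "normalized_max_diff": max_abs / ref_scale,
        "sum_diff": float(test64.sum() - ref64.sum()),
        "allclose": bool(np.allclose(test64, ref64, rtol=rtol, atol=atol)),
    }


if __name__ == "__main__":
    rng = np.random.default_rng(0)
    ref = rng.standard_normal((1, 8, 16, 16)).astype(np.float32)
    # Simulated "GPU" output: the reference plus small perturbations
    perturbed = ref + (1e-3 * rng.standard_normal(ref.shape)).astype(np.float32)
    print(compare_outputs(ref, ref))
    print(compare_outputs(ref, perturbed))
```

Reporting the maximum elementwise difference alongside the sum makes it easier to tell a systematic shift from scattered rounding noise.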

First, here are the results for PyTorch 1.10.1.
torch.utils.collect_env:

Collecting environment information...
PyTorch version: 1.10.1+cu111
Is debug build: False
CUDA used to build PyTorch: 11.1
ROCM used to build PyTorch: N/A

OS: Ubuntu 18.04.6 LTS (x86_64)
GCC version: (Ubuntu 7.5.0-3ubuntu1~18.04) 7.5.0
Clang version: Could not collect
CMake version: Could not collect
Libc version: glibc-2.27

Python version: 3.8.0 (default, Dec  9 2021, 17:53:27)  [GCC 8.4.0] (64-bit runtime)
Python platform: Linux-5.4.0-1030-aws-x86_64-with-glibc2.27
Is CUDA available: True
CUDA runtime version: 11.1.105
GPU models and configuration: 
GPU 0: A100-SXM4-40GB
GPU 1: A100-SXM4-40GB
GPU 2: A100-SXM4-40GB
GPU 3: A100-SXM4-40GB
GPU 4: A100-SXM4-40GB
GPU 5: A100-SXM4-40GB
GPU 6: A100-SXM4-40GB
GPU 7: A100-SXM4-40GB

Nvidia driver version: 450.80.02
cuDNN version: Probably one of the following:
/usr/lib/x86_64-linux-gnu/libcudnn.so.8.0.5
/usr/lib/x86_64-linux-gnu/libcudnn_adv_infer.so.8.0.5
/usr/lib/x86_64-linux-gnu/libcudnn_adv_train.so.8.0.5
/usr/lib/x86_64-linux-gnu/libcudnn_cnn_infer.so.8.0.5
/usr/lib/x86_64-linux-gnu/libcudnn_cnn_train.so.8.0.5
/usr/lib/x86_64-linux-gnu/libcudnn_ops_infer.so.8.0.5
/usr/lib/x86_64-linux-gnu/libcudnn_ops_train.so.8.0.5
HIP runtime version: N/A
MIOpen runtime version: N/A

Versions of relevant libraries:
[pip3] numpy==1.22.3
[pip3] torch==1.10.1+cu111
[conda] Could not collect

python test.py:

TORCH VERSION: 1.10.1+cu111
CUDA VERSION: 11.1
CUDNN VERSION: 8005

on CPU: -62344.9922
on GPU: -62346.3828

Note the numerical difference between the CPU and GPU results of F.conv2d. Exactly the same error happens when I upgrade PyTorch to v1.13 or downgrade it to v1.9. When I downgrade PyTorch to v1.8, however, the error is gone:

torch.utils.collect_env:

Collecting environment information...
PyTorch version: 1.8.1+cu111
Is debug build: False
CUDA used to build PyTorch: 11.1
ROCM used to build PyTorch: N/A

OS: Ubuntu 18.04.6 LTS (x86_64)
GCC version: (Ubuntu 7.5.0-3ubuntu1~18.04) 7.5.0
Clang version: Could not collect
CMake version: Could not collect

Python version: 3.8 (64-bit runtime)
Is CUDA available: True
CUDA runtime version: Could not collect
GPU models and configuration: 
GPU 0: A100-SXM4-40GB
GPU 1: A100-SXM4-40GB
GPU 2: A100-SXM4-40GB
GPU 3: A100-SXM4-40GB
GPU 4: A100-SXM4-40GB
GPU 5: A100-SXM4-40GB
GPU 6: A100-SXM4-40GB
GPU 7: A100-SXM4-40GB

Nvidia driver version: 450.80.02
cuDNN version: Probably one of the following:
/usr/lib/x86_64-linux-gnu/libcudnn.so.8.0.5
/usr/lib/x86_64-linux-gnu/libcudnn_adv_infer.so.8.0.5
/usr/lib/x86_64-linux-gnu/libcudnn_adv_train.so.8.0.5
/usr/lib/x86_64-linux-gnu/libcudnn_cnn_infer.so.8.0.5
/usr/lib/x86_64-linux-gnu/libcudnn_cnn_train.so.8.0.5
/usr/lib/x86_64-linux-gnu/libcudnn_ops_infer.so.8.0.5
/usr/lib/x86_64-linux-gnu/libcudnn_ops_train.so.8.0.5
HIP runtime version: N/A
MIOpen runtime version: N/A

Versions of relevant libraries:
[pip3] numpy==1.22.3
[pip3] torch==1.8.1+cu111
[conda] Could not collect

python test.py:

TORCH VERSION: 1.8.1+cu111
CUDA VERSION: 11.1
CUDNN VERSION: 8005

on CPU: -62344.9922
on GPU: -62344.9922

Notes:

  • This seems to happen specifically on the A100. My colleagues ran the same test on different GPUs and told me there’s no error (I don’t have their results).
  • While it may seem small, this error has a significant impact on deep conv nets, especially when the prediction requires high accuracy. For example, when I ran inference with the same model, I got 39% mAP (nuScenes 3D detection, higher is better) in the setting where test.py shows no error, but it dropped to 35% in the setting where test.py shows this error.
  • I could just stick to PyTorch 1.8.1, but for some reason training is much slower with it (by almost 50%) than with newer versions. I wish I could train faster with the newer versions but without this error.

Does anybody have a clue about this?

  1. I don’t know which input data range you are using, but based on the size of the errors I would guess they are caused by TF32 numerical precision. Could you disable it and recheck the results via:
     torch.backends.cuda.matmul.allow_tf32 = False
     torch.backends.cudnn.allow_tf32 = False
  2. Could you share the model and describe your use case a bit more, please? The assumption is that TF32 does not cause convergence issues, but your use case sounds concerning.

  3. See point 1. Disabling TF32 would still yield a performance drop, but might still be faster overall than dropping to an older PyTorch release (with older CUDA libraries).
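A rough sanity check of the TF32 guess, using only the two sums that test.py printed for PyTorch 1.10.1: the relative gap is about 2.2e-5, about 187 times FP32's machine epsilon (2^-23) but well below TF32's (2^-10, since TF32 keeps 10 explicit mantissa bits). This is a consistency check on two printed numbers, not proof. A sum over many outputs lets individual rounding errors partly cancel, so the gap is not expected to reach the TF32 epsilon.

```python
# Sums printed by test.py with PyTorch 1.10.1 on the A100
cpu_sum = -62344.9922
gpu_sum = -62346.3828

rel_gap = abs(gpu_sum - cpu_sum) / abs(cpu_sum)

# Machine epsilon is 2**-(number of explicit mantissa bits)
eps = {
    "fp32": 2.0 ** -23,  # 23 mantissa bits
    "tf32": 2.0 ** -10,  # 10 mantissa bits (same as fp16)
    "fp16": 2.0 ** -10,  # 10 mantissa bits
}

print(f"relative gap between sums: {rel_gap:.2e}")
for name, e in eps.items():
    print(f"{name}: eps = {e:.2e}, gap / eps = {rel_gap / e:.3g}")
```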


Hi @ptrblck, thanks for the response (not only on this topic but also on many others from which I found solutions)!

  1. Yup, those two lines solved my problem. Now I can reproduce my results. Thanks!

  2. The model is from this code. In fact, after reading about TF32, I think I now know what’s happening. My model predicts 3D bounding boxes from images, and it requires high precision (e.g. fp32), especially in the part that predicts geometric properties such as orientation. For this part of the computation I even had to opt out of mixed precision (AMP), because fp16 caused the training to diverge (for this computation I’m heavily using PyTorch3D). Looking at the description of TF32, its precision is much closer to fp16 than to fp32 (it keeps only 10 mantissa bits), so it makes sense that TF32 is not precise enough for my application.

  3. Yup, definitely will do this.
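To make the TF32 description concrete, here is a NumPy sketch that emulates TF32 inputs: it keeps FP32's sign and 8-bit exponent, rounds the 23-bit mantissa to 10 bits, and then accumulates in FP32, which is roughly how Tensor Cores handle TF32 matmuls and convolutions. The helper names are hypothetical, the round-to-nearest-even conversion is an assumption about the hardware's rounding mode, and NaN/Inf are not handled. Real cuDNN kernels also differ in accumulation order, so the numbers are illustrative only. The length 576 stands in for a hypothetical 3x3 kernel over 64 input channels.

```python
import numpy as np


def round_to_tf32(x):
    """Hypothetical helper: round float32 values to TF32 precision.

    TF32 keeps float32's sign and 8-bit exponent but only 10 of the 23
    mantissa bits. Uses round-to-nearest-even on the 13 dropped bits.
    NaN/Inf inputs are not handled.
    """
    bits = np.asarray(x, dtype=np.float32).view(np.uint32)
    lsb = (bits >> np.uint32(13)) & np.uint32(1)  # lowest kept mantissa bit
    bits = (bits + np.uint32(0x0FFF) + lsb) & np.uint32(0xFFFFE000)
    return bits.view(np.float32)


def mean_dot_errors(n=576, trials=200, seed=0):
    """Mean normalized error of FP32 vs emulated-TF32 dot products."""
    rng = np.random.default_rng(seed)
    fp32_errs, tf32_errs = [], []
    for _ in range(trials):
        a = rng.standard_normal(n).astype(np.float32)
        b = rng.standard_normal(n).astype(np.float32)
        a64, b64 = a.astype(np.float64), b.astype(np.float64)
        ref = np.dot(a64, b64)                    # float64 reference
        scale = np.dot(np.abs(a64), np.abs(b64))  # normalizer robust to cancellation
        fp32 = float(np.dot(a, b))                                # FP32 inputs
        tf32 = float(np.dot(round_to_tf32(a), round_to_tf32(b)))  # TF32-rounded inputs
        fp32_errs.append(abs(fp32 - ref) / scale)
        tf32_errs.append(abs(tf32 - ref) / scale)
    return float(np.mean(fp32_errs)), float(np.mean(tf32_errs))


if __name__ == "__main__":
    fp32_err, tf32_err = mean_dot_errors()
    print(f"mean normalized error, fp32 inputs: {fp32_err:.2e}")
    print(f"mean normalized error, tf32 inputs: {tf32_err:.2e}")
```

The emulated TF32 error should come out orders of magnitude above the FP32 one, and those per-layer errors can compound through a deep network.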

Thanks!

Hi @dnnspark

I have some questions.

  1. Is it possible to share the trained weights?
  2. Does the evaluation in scripts/train.py calculate the mAP score? If so, what’s the name of the metric? Is it mean_ap?

Background
I ran dd3d’s scripts/train.py (DLA34 model) on the nuScenes dataset using the pretrained weights from GitHub - TRI-ML/dd3d (Official PyTorch implementation of DD3D: Is Pseudo-Lidar needed for Monocular 3D Object detection? (ICCV 2021), Dennis Park*, Rares Ambrus*, Vitor Guizilini, Jie Li, and Adrien Gaidon), with a fixed random seed and without test-time augmentation.
I used the following command, with and without TF32 enabled: python scripts/train.py +experiments=dd3d_nusc_dla34 EVAL_ONLY=True MODEL.CKPT=/mkozuki/storage/2022/04/12_dd3d_tf32/dla34.pth TEST.IMS_PER_BATCH=12 TEST.AUG.ENABLED=False

In the log file, some values such as “nusc_val-subsample-8/tp_errors/trans_err” differ between FP32 and TF32.

Hi @crcrpar, please find the weights here. To reproduce the test result, try running this:

python scripts/train.py +experiments=dd3d_nusc_v99 MODEL.CKPT=model_final.pth EVAL_ONLY=True TEST.IMS_PER_BATCH=12 

With enough GPU memory, you can use a larger value for IMS_PER_BATCH, but make sure it’s a multiple of 6 (e.g. 18, 24, etc.).

And yes, mean_ap is the metric.

Thanks for the swift response.

Could you check the visibility of the weights? I couldn’t access them.

Sorry for the late response @crcrpar. Could you try this link?


Thank you so much for being kind and helpful. With your help, I think I finally reproduced the mAP: 0.389541 for FP32 and 0.355974 for TF32, without TTA.

You’re welcome @crcrpar. I’m looking forward to updates in TF32 so that it can be used in my scenarios.

Hi @dnnspark, my colleague identified the cause and found a resolution, so let me share it with you.

The mAP drop with TF32 can be alleviated or removed by deliberately disabling TF32 during post-processing. More specifically, in this case (the nuScenes dataset), disable TF32 for pytorch3d.transforms.Transform3d.get_matrix.
You can disable TF32 in the post-processing step of inference by setting torch.backends.cuda.matmul.allow_tf32 = False before the inference (or get_matrix) call and setting torch.backends.cuda.matmul.allow_tf32 = True after it.

It seems that even for these small 4x4 matrices, the difference in numerical format can affect mAP.
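One way to scope the flag change to just the sensitive computation, and to restore whatever value was set before instead of hard-coding True afterwards, is a small context manager. This is a hedged sketch: `temporarily_set` is a hypothetical helper, not a PyTorch or PyTorch3D API, and it is written generically (with a `SimpleNamespace` stand-in) so it runs without a GPU. With PyTorch you would pass the real module, e.g. `with temporarily_set(torch.backends.cuda.matmul, allow_tf32=False): matrix = transform.get_matrix()`, where `transform` is a hypothetical Transform3d instance.

```python
from contextlib import contextmanager
from types import SimpleNamespace


@contextmanager
def temporarily_set(obj, **overrides):
    """Hypothetical helper: set attributes on `obj` for the duration of a
    `with` block, then restore their previous values (even on exceptions)."""
    saved = {name: getattr(obj, name) for name in overrides}
    try:
        for name, value in overrides.items():
            setattr(obj, name, value)
        yield obj
    finally:
        for name, value in saved.items():
            setattr(obj, name, value)


if __name__ == "__main__":
    # Stand-in for torch.backends.cuda.matmul so the sketch runs without PyTorch
    matmul_backend = SimpleNamespace(allow_tf32=True)
    with temporarily_set(matmul_backend, allow_tf32=False):
        print("inside:", matmul_backend.allow_tf32)  # False
    print("after:", matmul_backend.allow_tf32)  # True
```

Restoring the saved value rather than unconditionally setting True matters because the default of torch.backends.cuda.matmul.allow_tf32 changed across releases (True through 1.11, False from 1.12 on).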

Thank you for being so cooperative and for bearing with my slow progress.