NVIDIA A100 80G very slow during training

I just bought an NVIDIA A100 80G GPU, but I am not able to use it because it is very slow. When I fine-tune a BERT model, it seems to be about 5x slower than an A100 40G.
Here is my environment information and the nvidia-smi output:

PyTorch version: 1.12.0+cu116
Is debug build: False
CUDA used to build PyTorch: 11.6
ROCM used to build PyTorch: N/A

OS: Rocky Linux release 8.6 (Green Obsidian) (x86_64)
GCC version: (GCC) 8.5.0 20210514 (Red Hat 8.5.0-10)
Clang version: Could not collect
CMake version: version 3.20.2
Libc version: glibc-2.28

Python version: 3.9.7 (default, Sep 16 2021, 13:09:58)  [GCC 7.5.0] (64-bit runtime)
Python platform: Linux-4.18.0-372.9.1.el8.x86_64-x86_64-with-glibc2.28
Is CUDA available: True
CUDA runtime version: Could not collect
GPU models and configuration: GPU 0: NVIDIA A100 80GB PCIe
Nvidia driver version: 515.48.07
cuDNN version: Could not collect
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True

Versions of relevant libraries:
[pip3] numpy==1.23.1
[pip3] torch==1.12.0+cu116
[pip3] torchaudio==0.12.0+cu116
[pip3] torchvision==0.13.0+cu116
[conda] blas                      1.0                         mkl  
[conda] mkl                       2021.4.0           h06a4308_640  
[conda] mkl-service               2.4.0            py39h7f8727e_0  
[conda] mkl_fft                   1.3.1            py39hd3c417c_0  
[conda] mkl_random                1.2.2            py39h51133e4_0  
[conda] numpy                     1.20.3           py39hf144106_0  
[conda] numpy-base                1.20.3           py39h74d4b33_0  
[conda] numpydoc                  1.1.0              pyhd3eb1b0_1

 NVIDIA-SMI 515.48.07    Driver Version: 515.48.07    CUDA Version: 11.7  

Do you know by chance what the problem could be? In this configuration it is impossible to use the card.

Thank you !

We would need more information to be able to debug this issue as we are not seeing these slowdowns on 80GB vs. 40GB A100s internally on BERT.


Sure, sorry. What kind of information can I give you?
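One way to produce comparable numbers is a small matmul micro-benchmark run on both the 80 GB and the 40 GB card. This is a sketch that is not from the thread: the matrix size, iteration count, and the choice of FP16 on GPU are arbitrary assumptions, and `matmul_tflops` / `describe_env` are hypothetical helper names. If raw matmul throughput on the two cards is similar, the BERT slowdown more likely comes from the training setup (library versions, data loading) than from the GPU itself.

```python
import time

import torch


def describe_env() -> str:
    """Summarize the PyTorch build and the visible GPU (if any)."""
    parts = [f"torch {torch.__version__}", f"CUDA build {torch.version.cuda}"]
    if torch.cuda.is_available():
        parts.append(f"device {torch.cuda.get_device_name(0)}")
    return " | ".join(parts)


def matmul_tflops(n: int = 4096, iters: int = 20, dtype=None, device=None) -> float:
    """Time repeated n x n matmuls and return approximate TFLOP/s.

    Assumption: FP16 on GPU and FP32 on CPU are reasonable defaults for a quick check.
    """
    if device is None:
        device = "cuda" if torch.cuda.is_available() else "cpu"
    if dtype is None:
        dtype = torch.float16 if device == "cuda" else torch.float32
    a = torch.randn(n, n, device=device, dtype=dtype)
    b = torch.randn(n, n, device=device, dtype=dtype)
    # Warm-up so one-time costs (context creation, kernel selection) are not timed.
    for _ in range(3):
        a @ b
    if device == "cuda":
        torch.cuda.synchronize()
    start = time.perf_counter()
    for _ in range(iters):
        a @ b
    if device == "cuda":
        # CUDA kernels run asynchronously; wait for them before stopping the clock.
        torch.cuda.synchronize()
    elapsed = time.perf_counter() - start
    flops = 2 * n ** 3 * iters  # one n x n matmul costs about 2 * n^3 FLOPs
    return flops / elapsed / 1e12
```

Running `print(describe_env())` and `print(f"{matmul_tflops():.1f} TFLOP/s")` on each machine gives two lines that can be posted side by side alongside the nvidia-smi output.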

I solved it by using the following torch version details in my env: