Bug in PyTorch 1.10 on NVIDIA RTX A6000

Hi there,

I ran the code below on RTX A6000 machines with 2 or 4 GPUs, and the CE loss becomes NaN after just a few iterations. When I checked, it turned out that the learnable parameters become NaN almost immediately after the first backpropagation. I then tested the same code on other GPUs (TITAN RTX, TITAN V, and Tesla V100 32GB) and it runs fine on all of them; only the RTX A6000 fails. I am wondering whether there is a PyTorch bug specific to the A6000. Could you please re-run my code on an A6000 to see whether this is a PyTorch bug?
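
(For anyone trying to reproduce this: a check along the following lines, run right after loss.backward(), shows which parameters or gradients turn NaN first. It is only a sketch with a hypothetical helper report_nan_params and is not part of the script below.)

import torch

def report_nan_params(model):
    # print the names of parameters whose values or gradients contain NaNs
    for name, param in model.named_parameters():
        if torch.isnan(param).any():
            print('NaN values in parameter:', name)
        if param.grad is not None and torch.isnan(param.grad).any():
            print('NaN values in gradient of:', name)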

The command I used to run the code is:
python -m torch.distributed.launch --master_port=6396 --nproc_per_node=2 debug_train_dist.py --ngpu 2 --reduction 8 -lr 0.001 -epoch 80 -nb_worker 8 -bs 50

The full project code to reproduce the issue is below. For reference, the PyTorch version on the RTX A6000 machine is 1.10.0.dev20210831; on the other GPUs I tested with PyTorch 1.6.0.

import numpy as np
from random import randrange
import torch
import torch.nn as nn
from torch.utils import data
from tqdm import tqdm
import argparse
import json


def get_args():
    parser = argparse.ArgumentParser()
    parser.add_argument('-bs', type=int, default=100)
    parser.add_argument('-lr', type=float, default=0.001)
    parser.add_argument('-epoch', type=int, default=80)
    parser.add_argument('-nb_worker', type=int, default=8)
    parser.add_argument('-seed', type=int, default=1234)
    parser.add_argument('-model_params', type=json.loads, default=
    '{"first_conv":3, "in_channels":1, "filts":[128, [128,128], [128,256], [256,256]],' \
    '"blocks":[2,4], "nb_fc_att_node":[1], "nb_fc_node":1024, "gru_node":1024, "nb_gru_layer":1}')
    parser.add_argument('--local_rank', type=int, default=0)
    parser.add_argument('--ngpu', type=int, default=2)
    parser.add_argument('--reduction', type=int, help='reduction rate')
    args = parser.parse_args()
    return args


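# Keras-style inverse-time decay: scales the base LR by 1 / (1 + decay * step)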
def keras_lr_decay(step, decay = 0.0001):
    return 1./(1. + decay * step)


class Residual_block_imgs(nn.Module):
    def __init__(self, nb_filts, shift, reduction=8, first = False):
        super(Residual_block_imgs, self).__init__()
        self.shift = shift
        self.lrelu = nn.LeakyReLU()
        self.lrelu_keras = nn.LeakyReLU(negative_slope=0.3)
        self.conv1 = nn.Conv1d(in_channels=nb_filts[0],
                               out_channels=nb_filts[1],
                               kernel_size=3,
                               padding=1,
                               stride=1)
        self.bn2 = nn.BatchNorm1d(num_features=nb_filts[1])
        self.conv2 = nn.Conv1d(in_channels=nb_filts[1],
                               out_channels=nb_filts[1],
                               padding=1,
                               kernel_size=3,
                               stride=1)

        if nb_filts[0] != nb_filts[1]:
            self.downsample = True
            self.conv_downsample = nn.Conv1d(in_channels=nb_filts[0],
                                             out_channels=nb_filts[1],
                                             padding=0,
                                             kernel_size=1,
                                             stride=1)
        else:
            self.downsample = False
        self.mp = nn.MaxPool1d(3)
        channel = nb_filts[0]
        self.avg_pool_t1 = nn.AdaptiveAvgPool1d(1)
        self.avg_pool_t2 = nn.AdaptiveAvgPool1d(1)
        self.down_t1 = nn.Sequential(
            nn.Linear(channel, channel // reduction, bias=False),
            nn.ReLU(inplace=True)
        )
        self.up_diff = nn.Sequential(
            nn.Linear(channel // reduction, channel, bias=False),
            nn.Sigmoid()
        )
        self.down_t2 = nn.Sequential(
            nn.Linear(channel, channel // reduction, bias=False),
            nn.ReLU(inplace=True)
        )

    def forward(self, x):
        identity = x
        x = self._middle(x)
        out = self.conv1(x)
        out = self.bn2(out)
        out = self.lrelu_keras(out)
        out = self.conv2(out)
        if self.downsample:
            identity = self.conv_downsample(identity)
        out += identity
        out = self.mp(out)
        return out

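    # _middle: squeeze-excitation-style channel gating computed from the difference
    # between two temporally shifted, globally average-pooled views of the input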
    def _middle(self, x):
        b, c, _ = x.size()
        length = x[:, :, self.shift:].shape[-1]
        x_t2 = x[:,:,self.shift:]
        x_t1 = x[:,:,:length]
        b_t2, c_t2, _ = x_t2.size()
        y_t2 = self.avg_pool_t2(x_t2).view(b_t2, c_t2)
        b_t1, c_t1, _ = x_t1.size()
        y_t1 = self.avg_pool_t1(x_t1).view(b_t1, c_t1)
        y_t2 = self.down_t2(y_t2)
        y_t1 = self.down_t1(y_t1)
        y = y_t2 - y_t1
        y = self.up_diff(y).view(b, c, 1)
        x = x * y.expand_as(x)
        return x


class Model_imgs(nn.Module):
    def __init__(self, d_args, reduction=8):
        super(Model_imgs, self).__init__()
        self.reduction = reduction
        self.first_conv = nn.Conv1d(
            in_channels=d_args['in_channels'],  # 1
            out_channels=d_args['filts'][0],  # 128
            kernel_size=d_args['first_conv'],  # 3
            stride=d_args['first_conv'])  # 3

        self.first_bn = nn.BatchNorm1d(
            num_features=d_args['filts'][0])  # 128
        self.lrelu_keras = nn.LeakyReLU(
            negative_slope=0.3)

        self.block0 = self._make_layer(
            nb_blocks=d_args['blocks'][0],  # 2
            nb_filts=d_args['filts'][1],  # [128, 128]
            shifts=[2667, 889],
            first=True)
        self.block1 = self._make_layer(
            nb_blocks=d_args['blocks'][1],  # 4
            nb_filts=d_args['filts'][2],  # [128, 256]
            shifts=[296, 99, 33, 11])
        self.bn_before_gru = nn.BatchNorm1d(
            num_features=d_args['filts'][2][-1])  # 256
        self.gru = nn.GRU(
            input_size=d_args['filts'][2][-1],  # 256
            hidden_size=d_args['gru_node'],  # 1024
            num_layers=d_args['nb_gru_layer'],  # 1
            batch_first=True)
        self.fc1_gru = nn.Linear(
            in_features=d_args['gru_node'],  # 1024
            out_features=d_args['nb_fc_node'])  # 1024
        self.fc2_gru = nn.Linear(
            in_features=d_args['nb_fc_node'],  # 1024
            out_features=d_args['nb_classes'],  # 6112
            bias=True)

    def forward(self, x):
        x = x.unsqueeze(1)
        x = self.first_conv(x)
        x = self.first_bn(x)
        x = self.lrelu_keras(x)
        x = self.block0(x)
        x = self.block1(x)
        x = self.bn_before_gru(x)
        x = self.lrelu_keras(x)
        x = x.permute(0, 2, 1)
        self.gru.flatten_parameters()
        x, _ = self.gru(x)
        x = x[:, -1, :]
        code = self.fc1_gru(x)
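        # rescale the embedding to a fixed L2 norm of 10 before the final classifier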
        code_norm = code.norm(p=2, dim=1, keepdim=True) / 10.
        code = torch.div(code, code_norm)
        out = self.fc2_gru(code)
        return out

    def _make_layer(self, nb_blocks, nb_filts, shifts, first = False):
        layers = []
        for i in range(nb_blocks):
            first = first if i == 0 else False
            shift = shifts[i]
            layers.append(Residual_block_imgs(nb_filts=nb_filts, shift=shift, reduction=self.reduction, first=first))
            if i == 0: nb_filts[0] = nb_filts[1]
        return nn.Sequential(*layers)


class Dataset_imgs(data.Dataset):
    def __init__(self):
        # dummy dataset: generates random samples on the fly, no state needed
        pass

    def __len__(self):
        return 10000

    def __getitem__(self, index):
        y = randrange(6112)
        X = np.random.random_sample((59049,)).astype(np.float32)
        return X, y


def train_model(model, db_gen, optimizer, epoch, args, device, lr_scheduler, criterion):
    model.train()
    if args.local_rank == 0:
        pbar = tqdm(total=len(db_gen))
    for idx_ct, (m_batch, m_label) in enumerate(db_gen):
        m_batch, m_label = m_batch.to(device), m_label.to(device)
        output = model(m_batch)
        loss = criterion(output, m_label)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        if args.local_rank == 0:
            pbar.set_description('epoch: %d, cce:%.3f'%(epoch, loss))
            pbar.update(1)
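        # the LR schedule is stepped every iteration, not every epoch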
        lr_scheduler.step()
    if args.local_rank == 0:
        pbar.close()


def main():
    args = get_args()
    args.model_params['nb_classes'] = 6112

    world_size = args.ngpu
    torch.distributed.init_process_group(
        'nccl',
        init_method='env://',
        world_size=world_size,
        rank=args.local_rank,
    )

    torch.manual_seed(args.seed)
    np.random.seed(args.seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

    cuda = torch.cuda.is_available()
    device = torch.device('cuda' if cuda else 'cpu')

    trainset = Dataset_imgs()

    sampler_devset = torch.utils.data.distributed.DistributedSampler(
        trainset,
        num_replicas=args.ngpu,
        rank=args.local_rank)

    trainset_gen = data.DataLoader(trainset,
                                 batch_size=args.bs,
                                 drop_last=True,
                                 pin_memory=True,
                                 sampler=sampler_devset,
                                 num_workers=args.nb_worker)

    torch.cuda.set_device(args.local_rank)
    model = Model_imgs(args.model_params, args.reduction)
    model.cuda()

    model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
    model = torch.nn.parallel.DistributedDataParallel(
        model,
        device_ids=[args.local_rank],
        output_device=args.local_rank,
        find_unused_parameters=False)

    criterion = nn.CrossEntropyLoss()

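    # two parameter groups: parameters with 'bn' in the name (BatchNorm) get no weight decay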
    params = [
        {
            'params': [
                param for name, param in model.named_parameters()
                if 'bn' not in name
            ]
        },
        {
            'params': [
                param for name, param in model.named_parameters()
                if 'bn' in name
            ],
            'weight_decay':
                0
        },
    ]
    optimizer = torch.optim.Adam(params,
                                 lr=args.lr,
                                 weight_decay=0.0001,
                                 amsgrad=True)
    lr_scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda step: keras_lr_decay(step))

    # Train
    for epoch in range(args.epoch):
        sampler_devset.set_epoch(epoch)
        if args.local_rank == 0:
            print('training epoch:', epoch + 1)
        train_model(model=model,
                    db_gen=trainset_gen,
                    args=args,
                    optimizer=optimizer,
                    lr_scheduler=lr_scheduler,
                    criterion=criterion,
                    device=device,
                    epoch=epoch)


if __name__ == '__main__':
    main()

If you need any other information, please do not hesitate to leave a message. Thanks in advance!


Are you using the binaries with CUDA 11.3 or 11.5? I assume you are not seeing any issues running the same code with a nightly binary on the other GPUs, or are you stuck on 1.6.0 in the other setups? (This would help identify whether the issue is in the framework itself or in the libraries.)

Thanks for your swift reply. I checked the environment on the A6000 machine using https://raw.githubusercontent.com/pytorch/pytorch/master/torch/utils/collect_env.py.

Collecting environment information...
PyTorch version: 1.10.0.dev20210831
Is debug build: False
CUDA used to build PyTorch: 11.1
ROCM used to build PyTorch: N/A

OS: Ubuntu 16.04.6 LTS (x86_64)
GCC version: (Ubuntu 5.4.0-6ubuntu1~16.04.12) 5.4.0 20160609
Clang version: Could not collect
CMake version: Could not collect
Libc version: glibc-2.23

Python version: 3.8.11 (default, Aug  3 2021, 15:09:35)  [GCC 7.5.0] (64-bit runtime)
Python platform: Linux-4.4.0-142-generic-x86_64-with-glibc2.17
Is CUDA available: True
CUDA runtime version: Could not collect
GPU models and configuration:
GPU 0: NVIDIA RTX A6000
GPU 1: NVIDIA RTX A6000
GPU 2: NVIDIA RTX A6000
GPU 3: NVIDIA RTX A6000

Nvidia driver version: 465.19.01
cuDNN version: Could not collect
HIP runtime version: N/A
MIOpen runtime version: N/A

Versions of relevant libraries:
[pip3] numpy==1.20.3
[pip3] torch==1.10.0.dev20210831
[pip3] torchaudio==0.10.0.dev20210831
[pip3] torchvision==0.11.0.dev20210831
[conda] blas                      1.0                         mkl
[conda] cudatoolkit               11.1.74              h6bb024c_0    nvidia
[conda] mkl                       2021.3.0           h06a4308_520
[conda] mkl-service               2.4.0            py38h7f8727e_0
[conda] mkl_fft                   1.3.0            py38h42c9631_2
[conda] mkl_random                1.2.2            py38h51133e4_0
[conda] numpy                     1.20.3           py38hf144106_0
[conda] numpy-base                1.20.3           py38h74d4b33_0
[conda] pytorch                   1.10.0.dev20210831 py3.8_cuda11.1_cudnn8.0.5_0    pytorch-nightly
[conda] torchaudio                0.10.0.dev20210831            py38    pytorch-nightly
[conda] torchvision               0.11.0.dev20210831      py38_cu111    pytorch-nightly

For the other GPUs, such as the TITAN RTX, the PyTorch 1.6.0 environment is shown below.

Collecting environment information...
PyTorch version: 1.6.0
Is debug build: False
CUDA used to build PyTorch: 10.1
ROCM used to build PyTorch: N/A

OS: Ubuntu 16.04.6 LTS (x86_64)
GCC version: (Ubuntu 5.4.0-6ubuntu1~16.04.12) 5.4.0 20160609
Clang version: Could not collect
CMake version: Could not collect
Libc version: glibc-2.23

Python version: 3.8.3 (default, Jul  2 2020, 16:21:59)  [GCC 7.3.0] (64-bit runtime)
Python platform: Linux-4.4.0-142-generic-x86_64-with-glibc2.10
Is CUDA available: True
CUDA runtime version: 10.1.243
GPU models and configuration:
GPU 0: TITAN RTX
GPU 1: TITAN RTX
GPU 2: TITAN RTX
GPU 3: TITAN RTX

Nvidia driver version: 440.31
cuDNN version: Could not collect
HIP runtime version: N/A
MIOpen runtime version: N/A

Versions of relevant libraries:
[pip3] efficientnet-pytorch==0.7.1
[pip3] numpy==1.20.1
[pip3] numpydoc==1.1.0
[pip3] torch==1.6.0
[pip3] torchaudio==0.6.0
[pip3] torchvision==0.7.0
[conda] blas                      1.0                         mkl
[conda] cudatoolkit               10.1.243             h6bb024c_0
[conda] efficientnet-pytorch      0.7.1                    pypi_0    pypi
[conda] mkl                       2020.1                      217
[conda] mkl-service               2.3.0            py38he904b0f_0
[conda] mkl_fft                   1.1.0            py38h23d657b_0
[conda] mkl_random                1.1.1            py38h0573a6f_0
[conda] numpy                     1.18.5                   pypi_0    pypi
[conda] numpy-base                1.18.5           py38hde5b4d6_0
[conda] numpydoc                  1.1.0                      py_0
[conda] pytorch                   1.6.0           py3.8_cuda10.1.243_cudnn7.6.3_0    pytorch
[conda] torchaudio                0.6.0                    pypi_0    pypi
[conda] torchvision               0.7.0                py38_cu101    pytorch

Does the above information help you identify what happened?

Yes, thanks! We’ll try to reproduce the issue using your setup. In the meantime, you could try updating to the latest nightly with the CUDA 11.5 runtime and cuDNN 8.2.0.

Thanks for your advice. I also tested my code on 1.10.1. print("Is cuDNN version:", torch.backends.cudnn.version()) reports 8200, and I still get the NaN error. The collected environment is below.

Collecting environment information...
PyTorch version: 1.10.1
Is debug build: False
CUDA used to build PyTorch: 11.3
ROCM used to build PyTorch: N/A

OS: Ubuntu 16.04.6 LTS (x86_64)
GCC version: (Ubuntu 5.4.0-6ubuntu1~16.04.12) 5.4.0 20160609
Clang version: Could not collect
CMake version: version 3.19.6
Libc version: glibc-2.10

Python version: 3.7.4 (default, Aug 13 2019, 20:35:49)  [GCC 7.3.0] (64-bit runtime)
Python platform: Linux-4.4.0-142-generic-x86_64-with-debian-stretch-sid
Is CUDA available: True
CUDA runtime version: Could not collect

GPU models and configuration:
GPU 0: NVIDIA RTX A6000
GPU 1: NVIDIA RTX A6000
GPU 2: NVIDIA RTX A6000
GPU 3: NVIDIA RTX A6000

Nvidia driver version: 465.19.01
cuDNN version: Could not collect
HIP runtime version: N/A
MIOpen runtime version: N/A

Versions of relevant libraries:
[pip3] numpy==1.19.5
[pip3] torch==1.10.1
[pip3] torchaudio==0.10.1+cu113
[pip3] torchvision==0.11.2
[conda] blas                      1.0                         mkl
[conda] cudatoolkit               11.3.1               h2bc3f7f_2
[conda] magma-cuda113             2.5.2                         1    pytorch
[conda] mkl                       2021.3.0           h06a4308_520
[conda] mkl-include               2021.3.0           h06a4308_520
[conda] mkl-service               2.4.0            py37h7f8727e_0
[conda] mkl_fft                   1.3.1            py37hd3c417c_0
[conda] mkl_random                1.2.2            py37h51133e4_0
[conda] numpy                     1.19.5                   pypi_0    pypi
[conda] pytorch                   1.10.1          py3.7_cuda11.3_cudnn8.2.0_0    pytorch
[conda] pytorch-mutex             1.0                        cuda    pytorch
[conda] torchaudio                0.10.1+cu113             pypi_0    pypi
[conda] torchvision               0.11.2               py37_cu113    pytorch
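
For completeness, the same version information can also be read directly from Python (a generic sanity check, nothing specific to this issue):

import torch
print('torch:', torch.__version__)
print('CUDA used to build torch:', torch.version.cuda)
print('cuDNN:', torch.backends.cudnn.version())
print('GPU 0:', torch.cuda.get_device_name(0))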

If it’s possible, could you try an NGC PyTorch container (PyTorch | NVIDIA NGC)? I’ve tried a few recent images (21.10, 21.08, 21.06), which should correspond to 1.10, 1.10, and 1.9 respectively, but couldn’t reproduce the NaN issue. Testing on the same driver is a bit involved; for reference, I was on 470.86.
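
(If you have not used NGC containers before, a command along these lines should work, adjusted to the tag you want to test: docker run --gpus all -it --rm --ipc=host nvcr.io/nvidia/pytorch:21.10-py3)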

Running the script on 2x A6000 GPUs, the loss stayed between 8 and 9 after 80 epochs.

EDIT: I’ve also tried the 1.10+cu111 pip wheel and was unable to reproduce the issue.

Thanks for your reply and for trying it out. Could you please tell me the exact command you used to install the 1.10+cu111 pip wheel, so I can check whether I get the same output as you?

Sure, I used pip3 install torch==1.10.0+cu111 -f https://download.pytorch.org/whl/cu111/torch_stable.html, but a newer version (e.g., 1.10.1 built with CUDA 11.3) should also work.

Thanks for the information. I installed it with your command in a new conda environment, but I still hit the NaN problem. I downloaded https://raw.githubusercontent.com/pytorch/pytorch/master/torch/utils/collect_env.py with wget and ran python collect_env.py; the collected environment is below.

Collecting environment information...
PyTorch version: 1.10.0+cu111
Is debug build: False
CUDA used to build PyTorch: 11.1
ROCM used to build PyTorch: N/A

OS: Ubuntu 16.04.6 LTS (x86_64)
GCC version: (Ubuntu 5.4.0-6ubuntu1~16.04.12) 5.4.0 20160609
Clang version: Could not collect
CMake version: Could not collect
Libc version: glibc-2.23

Python version: 3.8.12 (default, Oct 12 2021, 13:49:34)  [GCC 7.5.0] (64-bit runtime)
Python platform: Linux-4.4.0-142-generic-x86_64-with-glibc2.17
Is CUDA available: True
CUDA runtime version: Could not collect
GPU models and configuration:
GPU 0: NVIDIA RTX A6000
GPU 1: NVIDIA RTX A6000
GPU 2: NVIDIA RTX A6000
GPU 3: NVIDIA RTX A6000

Nvidia driver version: 465.19.01
cuDNN version: Could not collect
HIP runtime version: N/A
MIOpen runtime version: N/A

Versions of relevant libraries:
[pip3] numpy==1.22.1
[pip3] torch==1.10.0+cu111
[conda] numpy                     1.22.1                   pypi_0    pypi
[conda] torch                     1.10.0+cu111             pypi_0    pypi

Could you please check whether my collected environment differs significantly from yours?

Here’s what I got:

Collecting environment information...
PyTorch version: 1.10.0+cu111
Is debug build: False
CUDA used to build PyTorch: 11.1
ROCM used to build PyTorch: N/A

OS: Ubuntu 20.04.3 LTS (x86_64)
GCC version: (Ubuntu 9.3.0-17ubuntu1~20.04) 9.3.0
Clang version: Could not collect
CMake version: version 3.21.3
Libc version: glibc-2.31

Python version: 3.8.12 | packaged by conda-forge | (default, Oct 12 2021, 21:59:51)  [GCC 9.4.0] (64-bit runtime)
Python platform: Linux-5.11.0-41-generic-x86_64-with-glibc2.10
Is CUDA available: True
CUDA runtime version: 11.6.55
GPU models and configuration:
GPU 0: NVIDIA RTX A6000
GPU 1: NVIDIA RTX A6000

Nvidia driver version: 470.86
cuDNN version: Probably one of the following:
/usr/lib/x86_64-linux-gnu/libcudnn.so.8.3.2
/usr/lib/x86_64-linux-gnu/libcudnn_adv_infer.so.8.3.2
/usr/lib/x86_64-linux-gnu/libcudnn_adv_train.so.8.3.2
/usr/lib/x86_64-linux-gnu/libcudnn_cnn_infer.so.8.3.2
/usr/lib/x86_64-linux-gnu/libcudnn_cnn_train.so.8.3.2
/usr/lib/x86_64-linux-gnu/libcudnn_ops_infer.so.8.3.2
/usr/lib/x86_64-linux-gnu/libcudnn_ops_train.so.8.3.2
HIP runtime version: N/A
MIOpen runtime version: N/A

Versions of relevant libraries:
[pip3] mypy-extensions==0.4.3
[pip3] numpy==1.22.0
[pip3] pytorch-quantization==2.1.2
[pip3] torch==1.10.0+cu111
[pip3] torch-tensorrt==1.1.0a0
[pip3] torchtext==0.12.0a0
[pip3] torchvision==0.12.0a0
[conda] magma-cuda110             2.5.2                         5    local
[conda] mkl                       2019.5                      281    conda-forge
[conda] mkl-include               2019.5                      281    conda-forge
[conda] mypy_extensions           0.4.3            py38h578d9bd_4    conda-forge
[conda] numpy                     1.22.0           py38h6ae9a64_1    conda-forge
[conda] pytorch-quantization      2.1.2                    pypi_0    pypi
[conda] torch                     1.10.0+cu111             pypi_0    pypi
[conda] torch-tensorrt            1.1.0a0                  pypi_0    pypi
[conda] torchtext                 0.12.0a0                 pypi_0    pypi
[conda] torchvision               0.12.0a0                 pypi_0    pypi

I think the major differences are the Ubuntu version, the cuDNN version (not reported in your case), and the CUDA runtime version (which should be superseded by 11.1 in the wheels). If possible, I would recommend trying an NGC container to isolate more OS-level differences.

Thanks for this information. Since I only have access to one node of the cluster, I asked one of my colleagues to test my code on other nodes. The code works there, so I believe the node I was using has a problem itself rather than there being a bug in PyTorch. I will re-check that node's system to track down the problem.

@ptrblck
I think the node I used has a problem itself rather than PyTorch, so there is no need to spend more time reproducing it. Thanks for your help!

That’s interesting, as I’ve discussed this possibility with @eqy, too.
Could you run dmesg on the problematic node and check for Xid errors?

Thanks for your interest. The dmesg output is too long to post on this forum; is there a way to share information that exceeds the length limit?

I wouldn’t need to see the entire output, just the Xid messages if there are any.
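Something like dmesg | grep -i xid would pull out just those lines.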

I have also seen NaNs, but only on my Ampere devices, with many Xid errors in dmesg. This looks like a serious GPU driver bug to me. @ptrblck are you able to reproduce it?

Could you please provide some additional details (e.g., minimal reproduction script, driver version, GPU model(s))?