Training on RTX3090, tensor.to(device) very slow on some networks

Issue description

Recently our lab set up a new machine with an RTX 3090. I installed GPU driver 460.32, CUDA 11.2, and PyTorch via

conda install pytorch torchvision torchaudio cudatoolkit=11.0 -c pytorch

Then I tested it with a seq2seq model (LSTM->LSTM) I had used before: training was very fast and everything worked fine.

However, when I use this machine to train a TextCNN classification model, it is much slower than on my laptop with a GTX 1660 Ti (CUDA 10.2 + torch 1.7.0).

There is no error message during training, though. To find where the problem is, I added some timers to the program and found the difference.

    for step, (x_batch, y_batch) in enumerate(train_loader):
        t1 = time.perf_counter()  # time.clock() is deprecated; perf_counter() is the idiomatic replacement
        x_batch = x_batch.to(device)
        y_batch = y_batch.to(device)
        t2 = time.perf_counter()
        print('transfer data needs %s ms' % ((t2 - t1) * 1000))

        output = model(x_batch)
        # training steps...
        t3 = time.perf_counter()
        print('training step uses %s ms' % ((t3 - t2) * 1000))

On my laptop, the data transfer takes just 0.06 ms and the subsequent training steps take about 3.5~4.5 ms.
On the 3090 machine, however, the transfer takes about 102 ms, while the training steps take only 2.3 ms.
I also ran the same training on our lab's old server with an RTX 2080 (CUDA 11.0, torch 1.7.1): the transfer takes 0.03 ms and the training steps take 1.9 ms. (P.S. I also noticed in nvidia-smi that the GPU-Util on the 2080 machine stayed around 54% throughout training, while the 3090 machine stayed at 100% from the start.)
The code on these 3 machines is exactly the same, so why is tensor.to(device) so slow on the 3090?
(By the way, I also tried a GAN built from 3-layer MLPs, which works fine on my laptop and the old 2080 server too... But again, tensor.to(device) is slow on the 3090...)

The data loading code of the Seq2Seq and TextCNN models is almost identical, using

train = torch.utils.data.TensorDataset(x_train, y_train)
train_loader = DataLoader(train, batch_size=BATCH_SIZE, shuffle=True)
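
(A side note, not part of my original setup: pinned host memory is a common lever for slow host-to-device copies. A minimal sketch, reusing train, BATCH_SIZE, and device from above:)

    from torch.utils.data import DataLoader

    # pin_memory allocates batches in page-locked host memory,
    # which typically makes the host-to-device copy faster
    train_loader = DataLoader(train, batch_size=BATCH_SIZE, shuffle=True,
                              pin_memory=True)

    # inside the training loop, the copy can then be made non-blocking:
    x_batch = x_batch.to(device, non_blocking=True)
    y_batch = y_batch.to(device, non_blocking=True)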

System Info

Here is the system info of that new server with RTX3090.

PyTorch version: 1.7.1
Is debug build: False
CUDA used to build PyTorch: 11.0
ROCM used to build PyTorch: N/A

OS: Ubuntu 18.04.5 LTS (x86_64)
GCC version: (Ubuntu 7.5.0-3ubuntu1~18.04) 7.5.0
Clang version: Could not collect
CMake version: Could not collect

Python version: 3.7 (64-bit runtime)
Is CUDA available: True
CUDA runtime version: 11.2.142
GPU models and configuration: GPU 0: GeForce RTX 3090
Nvidia driver version: 460.32.03
cuDNN version: Probably one of the following:
/usr/lib/x86_64-linux-gnu/libcudnn.so.8.1.0
/usr/lib/x86_64-linux-gnu/libcudnn_adv_infer.so.8.1.0
/usr/lib/x86_64-linux-gnu/libcudnn_adv_train.so.8.1.0
/usr/lib/x86_64-linux-gnu/libcudnn_cnn_infer.so.8.1.0
/usr/lib/x86_64-linux-gnu/libcudnn_cnn_train.so.8.1.0
/usr/lib/x86_64-linux-gnu/libcudnn_ops_infer.so.8.1.0
/usr/lib/x86_64-linux-gnu/libcudnn_ops_train.so.8.1.0
HIP runtime version: N/A
MIOpen runtime version: N/A

Versions of relevant libraries:
[pip] numpy==1.19.5
[pip] numpydoc==0.9.2
[pip] torch==1.7.1
[pip] torchaudio==0.7.2
[pip] torchvision==0.8.2+cu110
[conda] blas 1.0 mkl
[conda] cudatoolkit 11.0.221 h6bb024c_0
[conda] mkl 2020.0 166
[conda] mkl-service 2.3.0 py37he904b0f_0
[conda] mkl_fft 1.0.15 py37ha843d7b_0
[conda] mkl_random 1.1.0 py37hd6b4f25_0
[conda] numpy 1.18.1 py37h4f9e942_0
[conda] numpy-base 1.18.1 py37hde5b4d6_1
[conda] numpydoc 0.9.2 py_0
[conda] pytorch 1.7.1 py3.7_cuda11.0.221_cudnn8.0.5_0 pytorch
[conda] torchaudio 0.7.2 pypi_0 pypi
[conda] torchvision 0.8.2+cu110 pypi_0 pypi


CUDA operations are executed asynchronously, so you would need to synchronize the code before starting and stopping the timer.
Also, note that the very first CUDA operation initializes the CUDA context etc. and is thus slower.
You could use the torch.utils.benchmark utilities to profile specific operations, which would make sure to properly add warmup iterations and synchronizations.
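
For example, a minimal sketch with a made-up tensor standing in for one batch:

    import torch
    from torch.utils import benchmark

    x = torch.randint(0, 8000, (16, 23))  # made-up batch of token ids
    timer = benchmark.Timer(
        stmt="x.to(device)",
        globals={"x": x, "device": torch.device("cuda")},
    )
    print(timer.timeit(100))  # runs a warmup and synchronizes CUDA for you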

If you are seeing a performance regression in the model execution, please post the model definition (if possible), so that we could test it with the latest cudnn release and check the performance on this GPU.
That being said, note that we are currently facing an issue in the PyTorch binaries with statically linked cudnn, which removes xmma kernels as described here, and we are still working on forcing these kernels into the binaries.

Thank you for your reply!
At a macro level: I print the result every 100 steps. On my 1660 Ti laptop (CUDA 10.2, torch 1.7.0) this output scrolls by quickly, about 2 lines per second (TextCNN is not a very complicated network). But on the 3090 I have to wait nearly 10 seconds for each line of output.

Also thank you for pointing out the asynchronous execution problem. I am sorry, but I am not sure how to apply torch.utils.benchmark to multi-line code... So I revised my timing method as

    for step, (x_batch, y_batch) in enumerate(train_loader):
        torch.cuda.synchronize()
        t1 = time.perf_counter()
        x_batch = x_batch.to(device)
        y_batch = y_batch.to(device)
        torch.cuda.synchronize()
        t2 = time.perf_counter()
        print('transfer data needs %s ms' % ((t2 - t1) * 1000))

        torch.cuda.synchronize()
        t3 = time.perf_counter()
        output = model(x_batch)
        output = output.squeeze()
        loss = loss_func(output, y_batch)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        torch.cuda.synchronize()
        t4 = time.perf_counter()
        print('training step uses %s ms' % ((t4 - t3) * 1000))

Since I usually train the network for 7 epochs, to give it a 'warm-up' I record the timings from the last steps of the last epoch.
The situation indeed changed! This time, on the 3090 machine, the data transfer takes about 0.038 ms, while the training steps take around 106 ms. My laptop still needs around 0.07 ms to transfer the data and about 4 ms for the training step...
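
(As another sanity check, GPU time can also be measured with CUDA events instead of a host timer; a minimal sketch, reusing model and x_batch from the loop above:)

    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)

    start.record()
    output = model(x_batch)  # forward pass only, as an example
    end.record()
    torch.cuda.synchronize()  # wait until both events have completed
    print('forward took %.3f ms' % start.elapsed_time(end))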

The problem then becomes why GPU training is slow on the 3090, so I used another method, torch.autograd.profiler.profile(), to see what happens:

    for step, (x_batch, y_batch) in enumerate(train_loader):
        torch.cuda.synchronize()
        t1 = time.perf_counter()
        with torch.autograd.profiler.profile(use_cuda=True) as prof:
            x_batch = x_batch.to(device)
            y_batch = y_batch.to(device)
        torch.cuda.synchronize()
        t2 = time.perf_counter()
        print('transfer data needs %s ms' % ((t2 - t1) * 1000))
        print(prof)

I used the same method on the training part mentioned above.
Since the result is too long, I will attach two pictures. It seems the aten::cudnn_convolution operation took the most time, which looks related to the problem you mentioned. (We both use a convolutional structure; I will post my model definition in the next reply.)
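
(Since the full profiler table is long, sorting the averaged results surfaces the dominant op directly; key_averages() and table() are part of the profiler API:)

    with torch.autograd.profiler.profile(use_cuda=True) as prof:
        output = model(x_batch)
    print(prof.key_averages().table(sort_by="cuda_time_total"))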

Finally, by the way, though I do not quite understand the linked-binary issue... since the problem seems to be in cudnn, I remember I installed the cudnn library through the Ubuntu package manager; maybe that is also a factor? I will try copying cudnn to /usr/local/cuda-11.2/ rather than /usr/lib/x86_64-linux-gnu/ later.


And here is my model definition, a typical TextCNN used for NLP intent classification.

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

def create_emb_layer(weights_matrix, non_trainable=False):
    # how many words in dict (matrix), embedding dim
    num_embeddings, embedding_dim = weights_matrix.shape
    emb_layer = nn.Embedding(num_embeddings, embedding_dim, padding_idx=0)
    emb_layer.weight = nn.Parameter(torch.tensor(weights_matrix, dtype=torch.float32))
    if non_trainable:
        emb_layer.weight.requires_grad = False

    return emb_layer, num_embeddings, embedding_dim

class TextCNN(nn.Module):
    def __init__(self, weights_matrix, class_num, kernel_num, kernel_sizes):
        super(TextCNN, self).__init__()
        Ci = 1
        Co = kernel_num

        self.embd, embd_num, embd_dim = create_emb_layer(weights_matrix, False)
        self.convs1 = nn.ModuleList([nn.Conv2d(Ci, Co, (k, embd_dim)) for k in kernel_sizes])
        # self.dropout = nn.Dropout(dropout)
        self.fc = nn.Linear(Co * len(kernel_sizes), class_num)

    def forward(self, x):
        x = self.embd(x)  # (batch_N, token_num (words in one sent), embd_dim)
        x = x.unsqueeze(1)  # (N, Ci=1 (channel; for text only 1), token_num, embd_dim)
        x = [F.relu(conv(x)).squeeze(3) for conv in self.convs1]
        x = [F.max_pool1d(i, i.size(2)).squeeze(2) for i in x]
        x = torch.cat(x, 1)  # concat results of the 3 kernel sizes
        # x = self.dropout(x)
        logit = self.fc(x)
        return logit

Before training, I extracted pre-trained English word embeddings from fastText and used them to initialize the embedding layer. The hyperparameters and model initialization code are:

# Hyperparameters
LR = 0.001
EPOCH = 7
BATCH_SIZE = 16
EMBD_DIM = 300
KERNEL_NUM = 16
KERNEL_SIZES = [3, 4, 5]
CLASS_NUM = 150

embedding_matrix = np.load('Embd_matrix_8000x300.npy')
model = TextCNN(embedding_matrix, CLASS_NUM, KERNEL_NUM, KERNEL_SIZES).to(device)
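
A hypothetical smoke test of the shapes (dummy_x is made up, sized like one real batch):

    device = torch.device('cuda')
    dummy_x = torch.randint(0, 8000, (BATCH_SIZE, 23), dtype=torch.long).to(device)
    logits = model(dummy_x)
    print(logits.shape)  # expected: torch.Size([16, 150])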

Thanks for the model definition! Could you post the input shapes you are using as well, please?

Your local CUDA toolkit and cudnn library won't be used if you've installed the conda binaries or the pip wheels.
You could indeed build from source and compare the speed. If you are seeing a speedup, it might come from a different cudnn version and/or from the usage of the missing xmma kernels in the binary.

Once you provide the input shapes, I could check the model perf. using the latest cudnn release as well as internal versions.


Okay, my input consists of X and Y. X contains the token sequences produced by running the actual sentences through the tokenizer; its shape is (18000, 23), as there are 18000 samples and the padded length is 23.
For example, part of the X is:

[[   1    8 2860   45    4   73    5  124    4  520    7   44    4  156
    49  553    2    0    0    0    0    0    0]
 [   1   15    7   23   11   13    5  124    4   14  121  242   52  401
    17  401    2    0    0    0    0    0    0] ...]

Y holds the label for each token sequence in X; its shape is (18000,). For example, part of Y is:

[0, 0, 0, 1, 1, ...]

And I use

train = torch.utils.data.TensorDataset(x_train, y_train)
train_loader = DataLoader(train, batch_size=BATCH_SIZE, shuffle=True)

to put them together for training.
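
(For completeness, both tensors are integer-typed. A sketch of how they would be built from the numpy arrays X and Y above; nn.Embedding requires long indices:)

    x_train = torch.from_numpy(X).long()  # token ids must be int64 for nn.Embedding
    y_train = torch.from_numpy(Y).long()  # class indices for the loss function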

Thank you for the suggestion!!
I built PyTorch from source with CUDA 11.0 following the official guide and finally succeeded! (P.S. I had already tried CUDA 11.0 with the conda-binary PyTorch and it was still slow, so building from source seems to be the solution.)
The running time is finally at a satisfying level: the data transfer takes 0.038 ms and the training steps take only 2.05 ms!

Note that at first I failed to build PyTorch because of a "library reference not found" error when using CUDA 11.2 (even though NVIDIA says CUDA 11.2 should be compatible with 11.0?). So I uninstalled CUDA 11.2, installed CUDA 11.0 and cudnn as usual, then successfully built and installed PyTorch from source, and the test results finally looked good.
Really thank you for patiently helping me solve the problem! :blush:

By the way, it seems the local CUDA and cudnn were indeed not used, as you said.

But I still do not know why the conda-binary PyTorch works fine on my old lab server with the RTX 2080 and CUDA 11.0... Maybe it is because of the new Ampere architecture of the RTX 30xx cards?


That’s good to hear! I assume the updated cudnn version, as well as potential calls into the previously missing kernels, might have accelerated the workload. :slight_smile:

Could you post the error message you were seeing? As we are building it with CUDA 11.2, it would be good to see what wasn't working in your setup.

The conda binaries ship with the specified CUDA runtime (as well as cudnn, NCCL, etc.), which in your case is 11.0. This CUDA release supports multiple GPU architectures, including your RTX 2080 (sm_75).
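
If you want to verify which architectures a given binary supports, something like this should work (torch.cuda.get_arch_list is available in recent releases):

    import torch

    print(torch.version.cuda)                   # CUDA runtime shipped with the binary
    print(torch.cuda.get_device_capability(0))  # (7, 5) for RTX 2080, (8, 6) for RTX 3090
    print(torch.cuda.get_arch_list())           # compute architectures compiled into the binary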

I am really sorry, but as I deleted the 'pytorch' folder after the build with CUDA 11.2 failed, it may be hard to recover the full error message from that time. (By the way, I remember I used CUDA 11.2 with cudnn 8.1, which I installed by decompressing cudnn-11.2-linux-x64-v8.1.0.77.tgz and copying the files with

cp cuda/lib64/* /usr/local/cuda-11.2/lib64/
cp cuda/include/* /usr/local/cuda-11.2/include/

Fortunately, I remember some details from when I hit the problem, since I googled some of the error messages. :sweat_smile: I hope this helps.

  1. I remember the build involves about 5990 files, right? My build error happened around [5660/5990]. At that point setup.py seemed to be building the 'test cudnn library', or say the 'building with cudnn' part. Then the error happened.

  2. The final error message was
    subprocess.CalledProcessError: Command '['cmake', '--build', '.', '--target', 'install', '--config', 'Release', '--', '-j', '32']' returned non-zero exit status 1.
    Before that, many error messages of this form appeared:
    libcublas.so: undefined reference to xxxxxx
    For now I can find only one exact record of a message I googled at the time, sorry about that:
    libcublas.so: undefined reference to `cublasltbiimatmulalgogetheuristic@libcublaslt.so.11'

I searched GitHub and this forum, but those threads did not match my situation. Then I searched the wider internet and found the following:

One person wrote in his blog that he was using CUDA 10.1 but wanted to compile a program against CUDA 10.0.
This produced a similar error, such as: libnvinfer.so: undefined reference to `__cudaPushCallConfiguration@libcudart.so.10.0'.
So he manually modified cuda100wrapper.map, disguising cudaGetDeviceCount@libcudart.so.10.1 as cudaGetDeviceCount@libcudart.so.10.0.
He did this only because his work required it; he mentioned it would be much faster to avoid the error by uninstalling CUDA 10.1 and installing CUDA 10.0.

So I uninstalled CUDA 11.2, installed CUDA 11.0 plus cudnn from cudnn-11.2-linux-x64-v8.1.0.77.tgz, and this time the build succeeded.
