.cuda() extremely slow after calling loss.backward()

I have a model that I apply to 3D data. To provide an example, I created a random dataset that always returns the same tensor of the desired size.
When executing the code below, the .cuda() call is extremely slow (~4 s), except for the first one (~0.008 s).
If I comment out loss.backward(), it becomes much faster (~0.008 s).

I also put the torch.utils.bottleneck reports, with and without loss.backward(), below.

import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from time import time

class RandomDataset(Dataset):

    def __init__(self, size=(1, 169, 208, 179), length=18):
        self.item = torch.rand(size)
        self.length = length

    def __len__(self):
        return self.length

    def __getitem__(self, idx):
        return self.item


class Test(nn.Module):
    """
    Classifier for a multi-class classification task
    """
    def __init__(self):
        super(Test, self).__init__()

        self.features = nn.Sequential(
            nn.Conv3d(1, 8, 3),
            nn.BatchNorm3d(8),
            nn.ReLU(),
            nn.MaxPool3d(2, 2),

            nn.Conv3d(8, 16, 3),
            nn.BatchNorm3d(16),
            nn.ReLU(),
            nn.MaxPool3d(2, 2),

            nn.Conv3d(16, 32, 3),
            nn.BatchNorm3d(32),
            nn.ReLU(),
            nn.MaxPool3d(2, 2),

            nn.Conv3d(32, 64, 3),
            nn.BatchNorm3d(64),
            nn.ReLU(),
            nn.MaxPool3d(2, 2)
        )

        self.classifier = nn.Sequential(
            nn.Linear(64 * 792, 1000),
            nn.ReLU(),

            nn.Linear(1000, 100),
            nn.ReLU(),

            nn.Linear(100, 2)

        )

    def forward(self, x):
        x = self.features(x)
        x = x.view(x.size(0), -1)
        x = self.classifier(x)

        return x


if __name__ == "__main__":
    batch_size = 3

    dataset = RandomDataset()
    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True, drop_last=True)

    model = Test()
    model = model.cuda()

    total_time = 0
    criterion = nn.CrossEntropyLoss()
    for i, data in enumerate(dataloader):
        t0 = time()
        data_gpu = data.cuda()
        t1 = time()
        total_time += t1 - t0
        print("Loading data on GPU", t1 - t0)

        output = model(data_gpu)
        labels = torch.Tensor([0] * batch_size).long().cuda()
        loss = criterion(output, labels)
        loss.backward()

print("Mean time on loading data on GPU:", total_time / (len(dataset) / batch_size))

With loss.backward()

--------------------------------------------------------------------------------
  Environment Summary
--------------------------------------------------------------------------------
PyTorch 1.0.0 compiled w/ CUDA 8.0.61
Running with Python 3.6 and 

`pip3 list` truncated output:
numpy (1.14.3)
--------------------------------------------------------------------------------
  cProfile output
--------------------------------------------------------------------------------
         8396 function calls (7632 primitive calls) in 23.731 seconds

   Ordered by: internal time
   List reduced from 268 to 15 due to restriction <15>

   ncalls  tottime  percall  cumtime  percall filename:lineno(function)
       46   22.976    0.499   22.976    0.499 {method 'cuda' of 'torch._C._TensorBase' objects}
       18    0.317    0.018    0.317    0.018 {method 'uniform_' of 'torch._C._TensorBase' objects}
       18    0.259    0.014    0.259    0.014 {built-in method addmm}
        6    0.086    0.014    0.086    0.014 {built-in method stack}
        1    0.042    0.042    0.042    0.042 {built-in method rand}
        6    0.018    0.003    0.018    0.003 {method 'run_backward' of 'torch._C._EngineBase' objects}
       24    0.005    0.000    0.005    0.000 {built-in method conv3d}
       24    0.002    0.000    0.002    0.000 {built-in method batch_norm}
       36    0.002    0.000    0.002    0.000 {built-in method threshold}
        1    0.002    0.002   23.731   23.731 script.py:1(<module>)
      404    0.002    0.000    0.004    0.000 /home/elina.thibeausutre/miniconda3/envs/pytorch/lib/python3.6/site-packages/torch/nn/modules/module.py:537(__setattr__)
       24    0.002    0.000    0.002    0.000 {built-in method torch._C._nn.max_pool3d_with_indices}
       24    0.001    0.000    0.004    0.000 /home/elina.thibeausutre/miniconda3/envs/pytorch/lib/python3.6/site-packages/torch/nn/modules/batchnorm.py:58(forward)
   150/12    0.001    0.000    0.276    0.023 /home/elina.thibeausutre/miniconda3/envs/pytorch/lib/python3.6/site-packages/torch/nn/modules/module.py:483(__call__)
1224/1003    0.001    0.000    0.004    0.000 {built-in method builtins.isinstance}

Without loss.backward()

--------------------------------------------------------------------------------
  Environment Summary
--------------------------------------------------------------------------------
PyTorch 1.0.0 compiled w/ CUDA 8.0.61
Running with Python 3.6 and 

`pip3 list` truncated output:
numpy (1.14.3)
--------------------------------------------------------------------------------
  cProfile output
--------------------------------------------------------------------------------
         8336 function calls (7572 primitive calls) in 4.048 seconds

   Ordered by: internal time
   List reduced from 263 to 15 due to restriction <15>

   ncalls  tottime  percall  cumtime  percall filename:lineno(function)
       46    3.318    0.072    3.318    0.072 {method 'cuda' of 'torch._C._TensorBase' objects}
       18    0.312    0.017    0.312    0.017 {method 'uniform_' of 'torch._C._TensorBase' objects}
       18    0.258    0.014    0.258    0.014 {built-in method addmm}
        6    0.082    0.014    0.082    0.014 {built-in method stack}
        1    0.040    0.040    0.040    0.040 {built-in method rand}
       24    0.006    0.000    0.006    0.000 {built-in method conv3d}
       36    0.004    0.000    0.004    0.000 {built-in method threshold}
       24    0.003    0.000    0.003    0.000 {built-in method batch_norm}
       24    0.003    0.000    0.003    0.000 {built-in method torch._C._nn.max_pool3d_with_indices}
      404    0.002    0.000    0.004    0.000 /home/elina.thibeausutre/miniconda3/envs/pytorch/lib/python3.6/site-packages/torch/nn/modules/module.py:537(__setattr__)
        1    0.002    0.002    4.048    4.048 script.py:1(<module>)
   150/12    0.001    0.000    0.281    0.023 /home/elina.thibeausutre/miniconda3/envs/pytorch/lib/python3.6/site-packages/torch/nn/modules/module.py:483(__call__)
       24    0.001    0.000    0.005    0.000 /home/elina.thibeausutre/miniconda3/envs/pytorch/lib/python3.6/site-packages/torch/nn/modules/batchnorm.py:58(forward)
    100/1    0.001    0.000    0.003    0.003 /home/elina.thibeausutre/miniconda3/envs/pytorch/lib/python3.6/abc.py:196(__subclasscheck__)
     24/1    0.001    0.000    1.952    1.952 /home/elina.thibeausutre/miniconda3/envs/pytorch/lib/python3.6/site-packages/torch/nn/modules/module.py:185(_apply)

Hi,

Keep in mind that the CUDA API is asynchronous and only synchronizes when it interacts with the CPU. The backward call therefore looks super fast, but that is only the time needed to queue the work onto the GPU. As soon as you need the result on the CPU (or send new data to the GPU), you have to wait for all of those computations to finish. So the long runtime you see for .cuda() is most certainly the time spent actually computing your backward pass. You can use torch.cuda.synchronize() to force synchronization in your code. Adding one just after the backward call should reduce the runtime of the .cuda()!
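For example, a minimal sketch of where the synchronization point could go in the loop above (reusing the loss variable from the original script):

from time import time

t0 = time()
loss.backward()            # only queues the backward kernels on the GPU
torch.cuda.synchronize()   # block until the queued kernels have actually finished
t1 = time()
print("Real backward time", t1 - t0)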

Thank you for your reply; indeed, it brought the .cuda() time back to that of a normal call. However, this means that the .backward() call takes nearly 4 s. Is that normal? Or can I do something about it?

This depends on your network.
The backward pass can roughly take up to 2x the time of the forward pass, and is usually around 1x the forward time.

By the way, there are notes in the docs about using the bottleneck tool with CUDA; maybe they will help you get proper timings.
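As a side note, one way to get per-operator CUDA timings is the autograd profiler. A minimal sketch (assuming the model, data_gpu, criterion and labels variables from the script above; the exact table columns may differ between versions):

import torch

# Profile one training step; use_cuda=True records CUDA kernel times as well.
with torch.autograd.profiler.profile(use_cuda=True) as prof:
    output = model(data_gpu)
    loss = criterion(output, labels)
    loss.backward()

# Print the recorded operators, sorted by total CUDA time.
print(prof.key_averages().table(sort_by="cuda_time_total"))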

I retimed everything myself using torch.cuda.synchronize() (see the code and output below): my forward pass takes ~0.3 s whereas my backward pass takes ~3.9 s. Are these the correct timings, and if so, is there anything I can do?

    for i, data in enumerate(dataloader):
        t0 = time()
        data_gpu = data.cuda()
        t1 = time()
        total_time += t1 - t0
        print("Loading data on GPU", t1 - t0)

        output = model(data_gpu)
        torch.cuda.synchronize()
        t2 = time()
        print("Real time forward pass", t2 - t1)
        labels = torch.Tensor([0] * batch_size).long().cuda()
        loss = criterion(output, labels)
        torch.cuda.synchronize()
        t3 = time()
        print("Real time for loss computation", t3 - t2)
        loss.backward()
        torch.cuda.synchronize()
        t4 = time()
        print("Real time for backward pass", t4 - t3)

outputs:

Real time forward pass 0.33179545402526855
Real time for loss computation 0.0008327960968017578
Real time for backward pass 3.920534610748291
Loading data on GPU 0.008051156997680664
Real time forward pass 0.26632165908813477
Real time for loss computation 0.0001857280731201172
Real time for backward pass 3.930469512939453
Loading data on GPU 0.00805807113647461
Real time forward pass 0.2658727169036865
Real time for loss computation 0.00015234947204589844
Real time for backward pass 3.920789957046509
Loading data on GPU 0.008044242858886719
Real time forward pass 0.2670416831970215
Real time for loss computation 0.00017786026000976562
Real time for backward pass 3.932731866836548
Loading data on GPU 0.008074760437011719
Real time forward pass 0.2663881778717041
Real time for loss computation 0.00015616416931152344
Real time for backward pass 3.9343338012695312
Loading data on GPU 0.008038997650146484
Real time forward pass 0.2661571502685547
Real time for loss computation 0.00015997886657714844
Real time for backward pass 3.925018072128296

Hi @14thibea, did you manage to speed up the backward pass? If so, how did you do it? I am having exactly the same problem.

Same problem here.
With the same code, the backward takes 14.7 s on one machine
and 62.3 s on another.

I think this is a CUDA or cuDNN problem.
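If it helps to narrow this down, a quick sketch for comparing the CUDA/cuDNN setup on the two machines (just standard PyTorch attributes, nothing specific to this thread):

import torch

# Versions of the CUDA toolkit PyTorch was built against and of the loaded cuDNN.
print("PyTorch:", torch.__version__)
print("CUDA:", torch.version.cuda)
print("cuDNN:", torch.backends.cudnn.version())
print("cuDNN enabled:", torch.backends.cudnn.enabled)
print("GPU:", torch.cuda.get_device_name(0))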