How to reduce cudaStreamSynchronize time

I’m using a Colab T4 GPU. I tried to use its TPU, but I was getting a JAX error, so I gave up.

My training data is around 13,500 images and my batch size is 24. I did a lot of research into optimization trying to get my model to train faster; the best I achieved was 42 minutes/epoch, which is a bit slow, since my loss is not decreasing and I need to keep tweaking my net.

This is my current code:


import os
import sys
from google.colab import drive
import torch
from torch import nn
from torch.utils.data import DataLoader
import torchvision as tv
from torchvision.transforms import v2


DRIVE_DEFAULT_PATH = '/content/drive'
if not os.path.exists(DRIVE_DEFAULT_PATH):
  drive.mount(DRIVE_DEFAULT_PATH)
DRIVE_DEFAULT_PATH = DRIVE_DEFAULT_PATH + '/MyDrive'
CLASS_DEFAULT_PATH = '/RNP'
ASSIGNMENT_PATH = '/Trabalho 01/Sports'
WORK_PATH = DRIVE_DEFAULT_PATH + CLASS_DEFAULT_PATH + ASSIGNMENT_PATH


def setLoader(path, batch_size, train):
  # Only the training loader is used below; the validation branch is omitted.
  if train:
    transforms = v2.Compose([v2.ToImage(),
                             v2.ToDtype(torch.float32, scale=True),
                             v2.Normalize(mean=[0.4713, 0.4699, 0.4548], std=[0.3081, 0.3020, 0.2961])])
    return DataLoader(tv.datasets.ImageFolder(WORK_PATH + path, transform=transforms),
                      batch_size=batch_size,
                      shuffle=True,
                      num_workers=2,       # two background worker processes
                      pin_memory=True,     # page-locked buffers for faster host-to-device copies
                      prefetch_factor=4)   # each worker keeps 4 batches ready


class CNN(nn.Module):
  def __init__(self):
    super().__init__()
    self.step1 = nn.Sequential(nn.Conv2d(3,400,3,padding=1),
                               nn.ReLU(),
                               #2nd
                               nn.Conv2d(400,400,5,padding=2),
                               nn.ReLU(),
                               nn.MaxPool2d(2),
                               #3rd
                               nn.Conv2d(400,200,3,padding=1),
                               nn.ReLU(),
                               #nn.MaxPool2d(2),
                               #4th
                               nn.Conv2d(200,200,7),
                               nn.ReLU(),
                               #nn.MaxPool2d(2),
                               #5th
                               nn.Conv2d(200,100,5,padding=2),
                               nn.ReLU(),
                               nn.MaxPool2d(3),                                                  
                               )
    self.step2 = nn.Sequential(nn.LazyLinear(100),
                               nn.Softmax(dim=1))
  def forward(self, x):    
    return self.step2(torch.flatten(self.step1(x), start_dim=1))


def train():
  lr = 0.1
  num_epochs = 100
  train_then_validation = True
  model = CNN().to(device)
  loss_fn = nn.CrossEntropyLoss()
  optim = torch.optim.SGD(model.parameters(), lr=lr)
  loss = None
  for epoch in range(num_epochs):
    with profile(activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA], record_shapes=True) as prof:
      with record_function("model_inference"):
        for batch, (X,y) in enumerate(train_loader):
          X,y = X.to(device), y.to(device)
          optim.zero_grad()
          y_hat = model(X)
          loss = loss_fn(y_hat, y)
          loss.backward()
          optim.step()
          if batch % 5 == 0:
            print(f'Batch: {batch}, Loss: {loss}')
          if batch == 5:
            break  # only profile the first few batches
    print(prof.key_averages().table(sort_by="cpu_time_total", row_limit=30))
    break  # stop after the first (profiled) epoch
    print(f'Epoch: {epoch}, Loss: {loss}')

device = (
    "cuda"
    if torch.cuda.is_available()
    else "mps"
    if torch.backends.mps.is_available()
    else "cpu"
)
torch.backends.cudnn.benchmark = True
print(f"Using {device} device")
from torch.profiler import profile, record_function, ProfilerActivity
train_loader = setLoader("/train", 24, True)
train()

I'm breaking on batch 5 to profile.

-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                                   Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg     Self CUDA   Self CUDA %    CUDA total  CUDA time avg    # of Calls  
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                        model_inference         0.96%     265.902ms        99.87%       27.790s       27.790s       0.000us         0.00%        4.356s        4.356s             1  
                                               aten::to         0.00%     422.000us        61.15%       17.016s     347.255ms       0.000us         0.00%       7.269ms     148.347us            49  
                                         aten::_to_copy         0.00%     294.000us        61.15%       17.015s        1.418s       0.000us         0.00%       7.271ms     605.917us            12  
                                            aten::copy_         0.00%     357.000us        61.15%       17.015s        1.418s       7.271ms         0.03%       7.271ms     605.917us            12  
                                  cudaStreamSynchronize        61.13%       17.011s        61.14%       17.011s        1.215s       0.000us         0.00%       0.000us       0.000us            14  
                                        cudaMemcpyAsync        30.64%        8.524s        30.64%        8.524s     608.869ms       0.000us         0.00%       0.000us       0.000us            14  
                                             aten::item         0.00%      41.000us        30.62%        8.521s        2.130s       0.000us         0.00%       2.000us       0.500us             4  
                              aten::_local_scalar_dense         0.00%      70.000us        30.62%        8.521s        2.130s       2.000us         0.00%       2.000us       0.500us             4  
enumerate(DataLoader)#_MultiProcessingDataLoaderIter...         7.03%        1.957s         7.03%        1.957s     326.131ms       0.000us         0.00%       0.000us       0.000us             6  
autograd::engine::evaluate_function: ConvolutionBack...         0.00%     548.000us         0.10%      27.778ms     925.933us       0.000us         0.00%       20.901s     696.699ms            30  
                                   ConvolutionBackward0         0.00%     355.000us         0.10%      27.230ms     907.667us       0.000us         0.00%       20.901s     696.699ms            30  
                             aten::convolution_backward         0.04%      12.381ms         0.10%      26.875ms     895.833us       20.791s        80.24%       20.901s     696.699ms            30  
                                       cudaLaunchKernel         0.09%      24.595ms         0.09%      24.595ms       6.694us       0.000us         0.00%       0.000us       0.000us          3674  
                                           aten::conv2d         0.00%     208.000us         0.08%      21.959ms     731.967us       0.000us         0.00%        4.032s     134.395ms            30  
                                      aten::convolution         0.00%     788.000us         0.08%      21.751ms     725.033us       0.000us         0.00%        4.032s     134.395ms            30  
                                     aten::_convolution         0.00%     800.000us         0.08%      20.963ms     698.767us       0.000us         0.00%        4.032s     134.395ms            30  
                                aten::cudnn_convolution         0.03%       8.871ms         0.07%      18.846ms     628.200us        3.809s        14.70%        3.809s     126.953ms            30  
                                Optimizer.step#SGD.step         0.01%       1.959ms         0.01%       2.471ms     411.833us       0.000us         0.00%       5.692ms     948.667us             6  
                                              aten::sum         0.00%       1.216ms         0.01%       1.643ms      45.639us     110.117ms         0.43%     110.117ms       3.059ms            36  
     autograd::engine::evaluate_function: ReluBackward0         0.00%     300.000us         0.01%       1.636ms      54.533us       0.000us         0.00%     319.620ms      10.654ms            30  
    autograd::engine::evaluate_function: AddmmBackward0         0.00%     211.000us         0.01%       1.568ms     261.333us       0.000us         0.00%       4.123ms     687.167us             6  
                                             aten::relu         0.00%     483.000us         0.01%       1.486ms      49.533us       0.000us         0.00%     220.201ms       7.340ms            30  
                                          ReluBackward0         0.00%     223.000us         0.00%       1.336ms      44.533us       0.000us         0.00%     319.620ms      10.654ms            30  
autograd::engine::evaluate_function: torch::autograd...         0.00%     486.000us         0.00%       1.198ms      16.639us       0.000us         0.00%       0.000us       0.000us            72  
                               aten::threshold_backward         0.00%     734.000us         0.00%       1.113ms      37.100us     319.620ms         1.23%     319.620ms      10.654ms            30  
                                             aten::add_         0.00%     709.000us         0.00%       1.097ms      36.567us     223.258ms         0.86%     223.258ms       7.442ms            30  
autograd::engine::evaluate_function: MaxPool2DWithIn...         0.00%     178.000us         0.00%       1.057ms      88.083us       0.000us         0.00%     328.911ms      27.409ms            12  
                                           aten::linear         0.00%      59.000us         0.00%       1.045ms     174.167us       0.000us         0.00%       2.205ms     367.500us             6  
                                         AddmmBackward0         0.00%     139.000us         0.00%       1.029ms     171.500us       0.000us         0.00%       4.059ms     676.500us             6  
                                        aten::clamp_min         0.00%     651.000us         0.00%       1.003ms      33.433us     220.201ms         0.85%     220.201ms       7.340ms            30  
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
Self CPU time total: 27.825s
Self CUDA time total: 25.909s

I need a way to cut the copy and synchronize time.

I thought about sending the optimizer or the loss function to the GPU, so that when I call backward and step there would be no need to send data from CPU to GPU, eliminating the synchronize time and a bit of the copy time.
There’s no way to send the DataLoader to the GPU the way ParallelLoader does for TPUs, so some copying of X and y to CUDA will always exist.
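One thing that might at least hide that copy (a sketch, assuming pin_memory=True as in the DataLoader above): passing non_blocking=True lets the host-to-device transfer overlap with GPU work instead of blocking on it.

for X, y in train_loader:
    # with pinned host memory the copy can run asynchronously
    X = X.to(device, non_blocking=True)
    y = y.to(device, non_blocking=True)

    optim.zero_grad(set_to_none=True)  # set_to_none avoids an extra memset kernel
    loss = loss_fn(model(X), y)
    loss.backward()
    optim.step()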

That copy might not be the bottleneck, assuming the DataLoader is fast enough at preloading samples in the background.

Check which operations are synchronizing and remove them if possible. E.g.:

if (batch % 5 == 0):
    print(f'Batch: {batch}, Loss: {loss}')

will synchronize the code since you are printing a CUDA tensor.
You can use torch.cuda.set_sync_debug_mode(debug_mode) to narrow down more operations.
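For example (a minimal sketch; the mode only flags operations that synchronize implicitly):

import torch

# warn (or raise) whenever an operation forces an implicit GPU synchronization,
# which points directly at the call responsible for cudaStreamSynchronize
torch.cuda.set_sync_debug_mode("warn")  # accepts "default", "warn", or "error"

x = torch.randn(10, device="cuda")
idx = x.nonzero()        # synchronizing op: emits a warning under "warn"
value = x.sum().item()   # .item() also synchronizes (copies the scalar to the CPU)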

Unrelated to this issue, but using nn.Softmax with nn.CrossEntropyLoss is wrong as the latter expects raw logits. Remove the nn.Softmax and the training might already improve.
Also, I would generally try to overfit a small dataset first (e.g. just 10 samples) before scaling the workload to the full dataset per epoch, which might be too slow.
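On the Softmax point, a minimal sketch of a classifier head that returns raw logits (CNNHead is a hypothetical stand-in for the step2 block above):

from torch import nn

class CNNHead(nn.Module):
    def __init__(self, num_classes=100):
        super().__init__()
        self.fc = nn.LazyLinear(num_classes)  # no nn.Softmax here

    def forward(self, x):
        # return logits; nn.CrossEntropyLoss applies log_softmax internally
        return self.fc(x)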

I’ll try in a few hours when Google lets me use their GPU again.

Now, about the batch/loss print: aren’t those on the CPU, since I only call .to(device) on X, y, and the model? Or does it move to the GPU when I do backward through the optimizer?

About the Softmax: I tried without it and the loss kept increasing until it hit NaN, so I looked at the documentation again, CrossEntropyLoss — PyTorch 2.1 documentation, and it seems that it’s only equivalent to LogSoftmax + NLLLoss when using reduction='none'. If so, I would need to take the .mean() of my loss when calling CrossEntropyLoss without reduction.

PyTorch will not move tensors behind your back to the CPU. If you’ve moved the model and input to the GPU, the output will also be on the GPU. Since the target as well as the model output are on the GPU, the loss will be computed on the GPU, too.
Printing this tensor will move it back to the CPU.
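A quick way to confirm where the tensors live (a sketch reusing model, y_hat and loss from the training loop above):

print(next(model.parameters()).device)  # cuda:0
print(y_hat.device, loss.device)        # cuda:0 cuda:0
loss_value = loss.item()                # only this copies the scalar to the CPU and synchronizes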

No, since in all cases a log_softmax followed by F.nll_loss will be applied.
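This is easy to check numerically (a small self-contained sketch, independent of the reduction mode):

import torch
import torch.nn.functional as F

logits = torch.randn(8, 100)               # raw scores for 8 samples, 100 classes
target = torch.randint(0, 100, (8,))

ce  = F.cross_entropy(logits, target)                    # default reduction='mean'
nll = F.nll_loss(F.log_softmax(logits, dim=1), target)   # same computation, spelled out

print(torch.allclose(ce, nll))  # True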

My Colab time just came back, sorry for the delay.

I removed every print from the train loop, and cudaStreamSynchronize still takes ~20 seconds. Since there’s no print, cudaMemcpyAsync is now gone, but the DataLoader is taking 14 seconds. That shouldn’t be possible: if cudaStreamSynchronize is taking so long, shouldn’t the workers be prefetching my data in the meantime?

You mean to only use a batch size of 10 per epoch? I didn’t fully understand your hint here. My dataset has 100 classes.

I don’t think the synchronization is caused by the DataLoader, which is why I suggested profiling the code with stack traces to isolate the call.
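Something along these lines (a sketch; train_step() is a hypothetical helper running a handful of batches):

from torch.profiler import profile, ProfilerActivity

# with_stack=True records Python stack traces, so the output shows which line
# of the script ends up in cudaStreamSynchronize
with profile(activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA],
             with_stack=True) as prof:
    train_step()

print(prof.key_averages(group_by_stack_n=5)
          .table(sort_by="self_cpu_time_total", row_limit=20))
prof.export_stacks("profiler_cpu_stacks.txt", metric="self_cpu_time_total")
prof.export_stacks("profiler_cuda_stacks.txt", metric="self_cuda_time_total")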

No, I mean to use 10 samples in total to make sure your code doesn’t have any other issues while narrowing down why NaNs are shown during the training. This issue is unrelated to the first one.
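A sketch of what that could look like (reusing WORK_PATH and the transforms from the code above):

import torch
import torchvision as tv
from torch.utils.data import DataLoader, Subset

# train repeatedly on a fixed set of 10 random samples; a correct model/loss
# setup should drive the loss close to zero on such a tiny set
full_dataset = tv.datasets.ImageFolder(WORK_PATH + "/train", transform=transforms)
indices = torch.randperm(len(full_dataset))[:10].tolist()
tiny_loader = DataLoader(Subset(full_dataset, indices), batch_size=10, shuffle=True)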

Had a few issues with the profiler returning an empty file, but now I got it. In the CPU stack there’s only one call to synchronize:
(318):_exit;torch/cuda/__init__.py(773):_synchronize;<built-in_function__cuda_synchronize> 38
I’m not sure how to upload both txt files, but I guess that won’t be needed since there’s only one call to synchronize.

I guess that because I ran it a few times before getting the stack file, the cudnn benchmark made the code a bit faster. This is the current profiler output:


-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                                   Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg     Self CUDA   Self CUDA %    CUDA total  CUDA time avg    # of Calls  
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                        model_inference        12.16%        6.960s        99.95%       57.193s       57.193s       0.000us         0.00%        5.125s        5.125s             1  
enumerate(DataLoader)#_MultiProcessingDataLoaderIter...        62.31%       35.652s        62.31%       35.652s        5.942s       0.000us         0.00%       0.000us       0.000us             6  
                                               aten::to         0.00%     167.000us        25.43%       14.549s     296.918ms       0.000us         0.00%       7.100ms     144.898us            49  
                                         aten::_to_copy         0.00%     240.000us        25.43%       14.549s        1.212s       0.000us         0.00%       7.100ms     591.667us            12  
                                            aten::copy_         0.00%     353.000us        25.42%       14.548s        1.212s       7.100ms         0.02%       7.100ms     591.667us            12  
                                  cudaStreamSynchronize        25.42%       14.544s        25.42%       14.544s        1.212s       0.000us         0.00%       0.000us       0.000us            12  
                                           aten::conv2d         0.00%     254.000us         0.04%      22.558ms     751.933us       0.000us         0.00%        4.799s     159.961ms            30  
                                       cudaLaunchKernel         0.04%      22.437ms         0.04%      22.437ms       6.117us       0.000us         0.00%       0.000us       0.000us          3668  
                                      aten::convolution         0.00%     718.000us         0.04%      22.304ms     743.467us       0.000us         0.00%        4.799s     159.961ms            30  
autograd::engine::evaluate_function: ConvolutionBack...         0.00%     510.000us         0.04%      22.024ms     734.133us       0.000us         0.00%       23.796s     793.206ms            30  
                                     aten::_convolution         0.00%     829.000us         0.04%      21.586ms     719.533us       0.000us         0.00%        4.799s     159.961ms            30  
                                   ConvolutionBackward0         0.00%     356.000us         0.04%      21.514ms     717.133us       0.000us         0.00%       23.796s     793.206ms            30  
                             aten::convolution_backward         0.02%       8.871ms         0.04%      21.158ms     705.267us       23.685s        79.99%       23.796s     793.206ms            30  
                                aten::cudnn_convolution         0.02%       9.068ms         0.03%      19.490ms     649.667us        4.566s        15.42%        4.566s     152.184ms            30  
                                        cudaMemcpyAsync         0.01%       3.735ms         0.01%       3.735ms     311.250us       0.000us         0.00%       0.000us       0.000us            12  
                                Optimizer.step#SGD.step         0.00%       2.425ms         0.01%       2.954ms     492.333us       0.000us         0.00%       5.664ms     944.000us             6  
                                             aten::relu         0.00%     561.000us         0.00%       1.555ms      51.833us       0.000us         0.00%     218.355ms       7.279ms            30  
                                              aten::sum         0.00%       1.087ms         0.00%       1.542ms      42.833us     111.566ms         0.38%     111.566ms       3.099ms            36  
    autograd::engine::evaluate_function: AddmmBackward0         0.00%     178.000us         0.00%       1.417ms     236.167us       0.000us         0.00%       4.599ms     766.500us             6  
     autograd::engine::evaluate_function: ReluBackward0         0.00%     248.000us         0.00%       1.342ms      44.733us       0.000us         0.00%     317.491ms      10.583ms            30  
                                          ReluBackward0         0.00%     167.000us         0.00%       1.094ms      36.467us       0.000us         0.00%     317.491ms      10.583ms            30  
                                           aten::linear         0.00%      59.000us         0.00%       1.074ms     179.000us       0.000us         0.00%       2.466ms     411.000us             6  
                                    cudaStreamWaitEvent         0.00%       1.015ms         0.00%       1.015ms       0.699us       0.000us         0.00%       0.000us       0.000us          1452  
autograd::engine::evaluate_function: torch::autograd...         0.00%     437.000us         0.00%       1.009ms      14.014us       0.000us         0.00%       0.000us       0.000us            72  
                                             aten::add_         0.00%     635.000us         0.00%       1.004ms      33.467us     233.315ms         0.79%     233.315ms       7.777ms            30  
                                        aten::clamp_min         0.00%     660.000us         0.00%     994.000us      33.133us     218.355ms         0.74%     218.355ms       7.279ms            30  
autograd::engine::evaluate_function: MaxPool2DWithIn...         0.00%     139.000us         0.00%     934.000us      77.833us       0.000us         0.00%     366.472ms      30.539ms            12  
                      Optimizer.zero_grad#SGD.zero_grad         0.00%     932.000us         0.00%     932.000us     155.333us       0.000us         0.00%       0.000us       0.000us             6  
                               aten::threshold_backward         0.00%     610.000us         0.00%     927.000us      30.900us     317.491ms         1.07%     317.491ms      10.583ms            30  
                                         AddmmBackward0         0.00%     155.000us         0.00%     903.000us     150.500us       0.000us         0.00%       4.530ms     755.000us             6  
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
Self CPU time total: 57.221s
Self CUDA time total: 29.609s

You are seeing 12 calls to cudaStreamSynchronize.

Oh, you’re right, that’s the only reference I found in the txt. I’m sending the link so you can check both the CUDA and CPU stacks: Debug - Google Drive
I had to use experimental_config=torch._C._profiler._ExperimentalConfig(verbose=True), because my torch version was only creating an empty file.

The DataLoader is taking more than twice as long as cudaStreamSynchronize.
I could try to increase the prefetch_factor, but I doubt that would help, since I only have 2 workers.

Edit:
Testing cudaStreamSynchronize with different batch counts:
only 1 batch: 1.2 ms
2 batches: 4.599 s
So there’s something between the first and second batch that’s causing the synchronization, but there shouldn’t be anything moving from GPU → CPU, only CPU → GPU (the input data).
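A fairer way to see where the time goes (a sketch reusing model, loss_fn, optim and a batch X, y from the code above): with asynchronous CUDA execution, the wait for still-running kernels from the previous batch gets attributed to whichever call synchronizes next, so a step should be timed between explicit synchronizations.

import time
import torch

torch.cuda.synchronize()            # make sure earlier kernels are done
t0 = time.perf_counter()

y_hat = model(X)
loss = loss_fn(y_hat, y)
loss.backward()
optim.step()

torch.cuda.synchronize()            # wait for this step's kernels to finish
print(f"step time: {time.perf_counter() - t0:.3f}s")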

**Solved**
The DataLoader is taking a lot of time, but there’s no way to know whether some of that reported time runs in parallel with the GPU compute and hence doesn’t slow down the run as a whole.

Edit2:
Since the DataLoader is taking so long, wouldn’t it be better to just load all the images as tensors in memory? The train dataset is about 350 MB of images. That little data could even be stored on the GPU instead of the CPU, so I wouldn’t even need to call .to(device) every batch.
What do you think, @ptrblck?
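A rough sketch of that idea (reusing tv, WORK_PATH, transforms and device from the code above; it assumes all images share the same size, and the decoded float32 tensors can end up much larger than the 350 MB of compressed files):

import torch
from torch.utils.data import DataLoader

dataset = tv.datasets.ImageFolder(WORK_PATH + "/train", transform=transforms)
loader = DataLoader(dataset, batch_size=256, num_workers=2)  # one-off pass to tensorize

images, labels = [], []
for X, y in loader:
    images.append(X)
    labels.append(y)
all_X = torch.cat(images).to(device)  # entire train set cached on the GPU
all_y = torch.cat(labels).to(device)

# training then indexes the cached tensors directly, no per-batch .to(device)
perm = torch.randperm(all_X.size(0), device=device)
for i in range(0, all_X.size(0), 24):
    idx = perm[i:i + 24]
    X_batch, y_batch = all_X[idx], all_y[idx]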
Edit3:
Instead of using the raw files from Drive, I just unzipped the dataset on the VM, and now with batch_size 32 over 10 batches the time used by the DataLoader was only 236 ms.
All attention back to fixing the Synchronize time!
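For anyone hitting the same issue, the extraction step looks roughly like this (Sports.zip and LOCAL_DATA_PATH are placeholder names):

import zipfile

LOCAL_DATA_PATH = '/content/sports'  # local disk on the Colab VM

# extract the archive from Drive once; reading thousands of small files through
# the Drive mount is what made the DataLoader slow
with zipfile.ZipFile(WORK_PATH + '/Sports.zip') as zf:
    zf.extractall(LOCAL_DATA_PATH)

# then build the loader from the local copy instead of WORK_PATH, e.g.:
# train_loader = DataLoader(tv.datasets.ImageFolder(LOCAL_DATA_PATH + '/train', transform=transforms), ...)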