Time to load training batch to GPU varying with model size

It looks like the time to copy one batch of data from the CPU to the GPU varies with the model size, or maybe with the inference time: if the model is larger or its inference takes longer, the copy appears to take longer. I can’t understand why I am seeing this behavior.

Below is the code to reproduce the issue, along with the results:

import time
import torch
import torch.utils.data as data_utils
import torchvision.models as models

train_data = torch.randn(800, 3, 640, 640)
train_labels = torch.ones(800).long()

train = data_utils.TensorDataset(train_data, train_labels)
train_loader = data_utils.DataLoader(train, batch_size=2, shuffle=True)

#model = models.densenet161().cuda() # OPTION 1
model = models.resnet18().cuda() # OPTION 2
for x,y in train_loader:
    st = time.time()
    x, y = x.cuda(), y.cuda()
    print(time.time()-st)
    
    pred = model(x)

As you can see, the code prints the time taken to copy x and y to the GPU.

Resnet :

0.0025365352630615234
0.006475925445556641
0.009412765502929688
0.008988618850708008
0.009799957275390625
0.009394407272338867
0.009585857391357422
0.00932931900024414
0.009220361709594727

Densenet :

0.046151161193847656
0.04446148872375488
0.041242122650146484
0.03512907028198242
0.03663516044616699
0.03776717185974121
0.03576469421386719
0.03654170036315918
0.03631234169006348
0.03648805618286133

The architectures don’t actually need to be different to reproduce this.
Example: just run the forward pass multiple times on the resnet:

model = models.resnet18().cuda()
for x,y in train_loader:
    st = time.time()
    x, y = x.cuda(), y.cuda()
    print(time.time()-st)
    
    pred = model(x)
    pred = model(x)
    pred = model(x)
    pred = model(x)

Outputs:

0.03314852714538574
0.03551316261291504
0.03101205825805664
0.031397342681884766
0.03133535385131836
0.03079056739807129
0.025241851806640625
0.029177427291870117
0.029461145401000977
0.028394699096679688

So the reported copy time is more than 10 times what it is with just one forward pass.

On a different machine with a different GPU but the same PyTorch version (1.3.1), this behavior is not reproducible with batch_size=1, but it is with batch_size=2. The input size also seems to matter.

With

train_data = torch.randn(800, 3, 224, 224)

and batch_size=4 this behavior is not reproducible but with batch_size=16 it is.

I’m not really sure what is happening here. Even if this is related to asynchronicity, shouldn’t a longer inference time ensure that the data is already copied in the background?

The PyTorch version used is 1.3.1.

CUDA operations are asynchronous, so you should synchronize via torch.cuda.synchronize() before starting and stopping the timer.
Most likely your code is waiting at the x = x.cuda() call while the model is still running in the background, which folds the actual model time into the measured data transfer time.
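For reference, a minimal sketch of the timing loop with synchronizations added (reusing the train_loader and model from your snippet above):

for x, y in train_loader:
    torch.cuda.synchronize()          # wait for the previous forward pass to finish
    st = time.time()
    x, y = x.cuda(), y.cuda()
    torch.cuda.synchronize()          # wait until the copy itself has finished
    print(time.time() - st)

    pred = model(x)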

So you mean to say that x and y from the next batch will not be loaded onto the GPU while model inference is happening.
With torch.cuda.synchronize(), Python will wait at that line for the model inference to complete and only then execute st = time.time(), while without it, it proceeds to x.cuda() and waits there instead.

If that’s true, is there a way to load the next batch while model inference is going on? That’s what I’m trying to do.

You could try to use pinned memory and call to('cuda', non_blocking=True).
If the GPU is busy, the transfer will just be added to the queue.
If you are after the best peak performance, you could have a look at this prefetcher, although I would not recommend it for a general use case.
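A rough sketch of the pinned-memory approach (the num_workers value is just illustrative):

train_loader = data_utils.DataLoader(
    train, batch_size=2, shuffle=True,
    num_workers=2,        # preload batches in background worker processes
    pin_memory=True)      # page-locked host memory enables truly async H2D copies

for x, y in train_loader:
    # non_blocking=True queues the copy without blocking the host,
    # so it can overlap with the previous forward pass on the GPU
    x = x.to('cuda', non_blocking=True)
    y = y.to('cuda', non_blocking=True)
    pred = model(x)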


So I was able to use that prefetcher with my original case and it has improved performance quite a bit.
However, could you clarify why you don’t recommend it for a general use case?

Mainly because using multiple workers in the DataLoader should be sufficient, and this (advanced) approach might have negative side effects, e.g. if you use other streams inside your code.

That makes sense. Thanks for your help!

Hi @ptrblck,

I suspect that my model has a data loading bottleneck. Is the code below the correct way to time the operations?

for epoch in range(n_epoch):
    epoch_tic = time.time()
    for x_batch, y_batch in training_loader:
        batch_load_tic = time.time()
        x_batch = x_batch.to(device)
        y_batch = y_batch.to(device)
        batch_load_toc = time.time()

        nnet = nnet.train()
        yhat = nnet(x_batch)
        loss = loss_fn(yhat, y_batch)
        
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        update_toc = time.time()
    epoch_toc = time.time()

I ask because I’ve been training models on a GPU, but lately the jobs have become mysteriously slow for some reason, and I wanted to see which part of the training process is taking longer than expected.

What I observe is

epoch_toc - epoch_tic = 45 min
batch_load_toc - batch_load_tic = 0.005s
update_toc - batch_load_toc = 0.0423s

One epoch in my case contains 250 batches, but (0.005 + 0.0423) * 250 is only about 12 seconds. So I’m wondering where the remaining ~45 minutes go, and is there a way to time that?

Thanks!

Your current script tries to profile the host-to-device transfer time, which won’t work as written: CUDA operations are asynchronous, so you would need to synchronize the code via torch.cuda.synchronize() before starting and stopping the timer.
To profile the data loading, you could use the data_time object from the ImageNet example.
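The idea there is to stop a timer right after the DataLoader yields a batch, before any CUDA work has been queued; a simplified sketch of that pattern (without the AverageMeter helper used in the actual example):

data_time = 0.
end = time.time()
for x_batch, y_batch in training_loader:
    # time spent waiting for the DataLoader to produce the next batch;
    # no CUDA work has been queued yet, so no synchronization is needed here
    data_time += time.time() - end

    # ... transfer to the device, forward/backward, optimizer.step() ...

    end = time.time()
print(data_time)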

Hi @ptrblck,

Thank you so much for your reply! If I only care about

  1. time to finish one epoch
  2. time of entire mini-batch
  3. time of loading each mini-batch

then I don’t need to use torch.cuda.synchronize(); instead I can just modify the code as follows:

for epoch in range(n_epoch):
    epoch_tic = time.time()
    tic = time.time()
    for x_batch, y_batch in training_loader:
        
        batch_load_time = time.time() - tic
        
        x_batch = x_batch.to(device)
        y_batch = y_batch.to(device)
        
        nnet = nnet.train()
        yhat = nnet(x_batch)
        loss = loss_fn(yhat, y_batch)
        
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        
        batch_time = time.time() - tic
        tic = time.time()
    
    epoch_time = time.time() - epoch_tic

And correspondingly,

  1. epoch_time is the time to finish one epoch
  2. batch_load_time is the time to load one mini-batch of data
  3. batch_time is the time to finish the training of one mini-batch

Is this correct?

batch_load_time should give you the latency of the DataLoader, so it should be close to zero if the DataLoader is fast enough to preload the next batches.

However, epoch_time and batch_time will not give you accurate results without synchronizations, as the GPU might still be busy when the host code stops the timer.
To get valid profiles, you would have to add synchronizations.

Hi @ptrblck,

Thank you for your reply!

So the correct way would be to add torch.cuda.synchronize() in front of every time.time() call, right?

for epoch in range(n_epoch):
    torch.cuda.synchronize()
    epoch_tic = time.time()
    tic = time.time()
    for x_batch, y_batch in training_loader:
        torch.cuda.synchronize()
        batch_load_time = time.time() - tic
        
        x_batch = x_batch.to(device)
        y_batch = y_batch.to(device)
        
        nnet = nnet.train()
        yhat = nnet(x_batch)
        loss = loss_fn(yhat, y_batch)
        
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        torch.cuda.synchronize()
        batch_time = time.time() - tic
        tic = time.time()
    torch.cuda.synchronize()
    epoch_time = time.time() - epoch_tic

Yes, this would give you accurate timings. Once you have the profiles, you should remove the synchronizations again to allow asynchronous execution.

Note that the first iterations (or each iteration using a new input shape) will be slower if you are using torch.backends.cudnn.benchmark = True, as cudnn will run some benchmarking for the current workload.

Also, usually you would add some warmup iterations, but this might not be necessary if you just want to narrow down the bottleneck.
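A rough sketch of such a warmup, reusing nnet, training_loader, and device from your snippets above (the number of iterations is arbitrary):

warmup_iters = 5
loader_iter = iter(training_loader)
for _ in range(warmup_iters):
    x_batch, y_batch = next(loader_iter)
    x_batch = x_batch.to(device)
    nnet(x_batch)                 # forward passes trigger cudnn benchmarking etc.
torch.cuda.synchronize()          # make sure the warmup work has finished
# ... start the timed training loop here ...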