Does the DataLoader's CPU->GPU transfer use pinned memory with num_workers>0?

My current training bottleneck has been identified as the data transfer to the GPU. I noticed that with num_workers>0 I seem to hit a wall: the GPU transfer rate approaches that of non-pinned data transfer, which is far below the rate of pinned transfer. The benchmark below is based on the example from the thread "Slow CPU<=>GPU transfer":

import torch
import time
#import os
#os.environ["CUDA_VISIBLE_DEVICES"] = '0'
rank = 0
gpu_device = torch.device("cuda:"+str(rank))

Batch = 512
N = 25
C = 50
W = 500
n_samples = 50

w1 = [torch.randn(Batch, N, C, W, device=torch.device("cpu")) for x in range(n_samples)]
w1_pinned = [torch.nn.Parameter(x).pin_memory() for x in w1]
w1_not_pinned = [torch.nn.Parameter(x) for x in w1]
x1 = torch.randn(Batch, N, C, W, dtype=torch.float32, device=gpu_device)

#Run a Pinned Memory Transfer Test
torch.cuda.synchronize()
t_start = time.time()
for i in range(n_samples):
     x1.copy_(w1_pinned[i].data, non_blocking=True)
torch.cuda.synchronize()
t_end = time.time()
t_time = (t_end - t_start)*1000
t_size = n_samples*(Batch*N*C*W)*4/1024/1024
t_bw = t_size/(t_end-t_start)
print('Pinned Test')
print('size of transfer ', t_size, 'MB')
print('time taken by transfer ', t_time, 'mSec')
print('Effective bandwidth ', t_bw, 'MBps (Pinned)')

#Run a Non-Pinned Memory Transfer Test
torch.cuda.synchronize()
t_start = time.time()
for i in range(n_samples):
     x1.copy_(w1_not_pinned[i].data, non_blocking=True)
torch.cuda.synchronize()
t_end = time.time()
t_time = (t_end - t_start)*1000
t_size = n_samples*(Batch*N*C*W)*4/1024/1024
t_bw = t_size/(t_end-t_start)
print('Not Pinned Test')
print('size of transfer ', t_size, 'MB')
print('time taken by transfer ', t_time, 'mSec')
print('Effective bandwidth ', t_bw, 'MBps (Not Pinned)')

I get the following consistent results on a V100:

Pinned Test
size of transfer 61035.15625 MB
time taken by transfer 5160.215616226196 mSec
Effective bandwidth 11828.024406204298 MBps (Pinned)
Not Pinned Test
size of transfer 61035.15625 MB
time taken by transfer 29578.11951637268 mSec
Effective bandwidth 2063.5238902261713 MBps (Not Pinned)
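As a sanity check, the printed sizes and bandwidths can be reproduced by hand. The script divides by 1024 twice, so its "MB" is really MiB (2^20 bytes). The helper names below are illustrative only, not part of the original script:

```python
# Reproduce the size/bandwidth arithmetic behind the output above.
# Constants come from the benchmark script; helper names are illustrative.
BATCH, N, C, W = 512, 25, 50, 500
BYTES_PER_FLOAT32 = 4
MIB = 1024 * 1024


def sample_bytes(batch=BATCH, n=N, c=C, w=W):
    """Bytes in one float32 sample of shape (batch, n, c, w)."""
    return batch * n * c * w * BYTES_PER_FLOAT32


def transfer_mib(n_samples):
    """Total transfer size in MiB (what the script prints as 'MB')."""
    return n_samples * sample_bytes() / MIB


def bandwidth_mib_s(size_mib, time_ms):
    """Effective bandwidth in MiB/s from a size in MiB and a time in ms."""
    return size_mib / (time_ms / 1000)


size = transfer_mib(50)
print(sample_bytes())                                    # 1280000000
print(size)                                              # 61035.15625
print(round(bandwidth_mib_s(size, 5160.215616226196)))   # about 11828 (pinned)
print(round(bandwidth_mib_s(size, 29578.11951637268)))   # about 2064 (not pinned)
# Number of samples implied by the later 244140.625 MB runs,
# assuming Batch, N, C and W are unchanged:
print(244140.625 * MIB / sample_bytes())                 # 200.0
```

The last line rests on the assumption that only n_samples changed for the later runs; under that assumption those transfers correspond to n_samples = 200.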

Now if I throw in a simple DataLoader to accomplish this transfer (as you would in a training loop):

from torch.utils.data import DataLoader

class pytorch_dataset(torch.utils.data.Dataset):
     def __init__(self, samples):
          self.samples = samples

     def __len__(self):
          return len(self.samples)

     def __getitem__(self, item):
          return self.samples[item]

def DataLoader_Test(torch_loader, num_workers):
    torch.cuda.synchronize()
    t_start = time.time()
    for x in torch_loader:
        sample = x.to('cuda', non_blocking=True)
    torch.cuda.synchronize()
    t_end = time.time()

    t_time = (t_end - t_start)*1000
    t_size = n_samples*(Batch*N*C*W)*4/1024/1024
    t_bw = t_size/(t_end-t_start)

    print('DataLoader Test: ', num_workers, ' workers')
    print('size of transfer ', t_size, 'MB')
    print('time taken by transfer ', t_time, 'mSec')
    print('Effective bandwidth ', t_bw, 'MBps (DataLoader:',num_workers,' workers)')

num_workers=0
torch_loader = DataLoader(pytorch_dataset(w1),
                          batch_size=None,
                          num_workers=num_workers,
                          pin_memory=True)
DataLoader_Test(torch_loader, num_workers)

num_workers=1
torch_loader = DataLoader(pytorch_dataset(w1),
                          batch_size=None,
                          num_workers=num_workers,
                          pin_memory=True)

DataLoader_Test(torch_loader, num_workers)

Again, the results on the V100:

DataLoader Test: 0 workers
size of transfer 61035.15625 MB
time taken by transfer 5572.557687759399 mSec
Effective bandwidth 10952.80832786513 MBps (DataLoader: 0 workers)
DataLoader Test: 1 workers
size of transfer 61035.15625 MB
time taken by transfer 36907.78851509094 mSec
Effective bandwidth 1653.7202229021066 MBps (DataLoader: 1 workers)

As the number of samples increases, the DataLoader overhead diminishes and the transfer rate approaches that of the benchmarks above. (More workers just bind things up, since they all access the same data in memory.)
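The amortization argument can be illustrated with a simple cost model. This model is an assumption, not something measured in this thread: total time is a one-off fixed cost (for example, starting a worker process) plus a cost per batch. Only the fixed part gets amortized as more samples are added; any per-batch overhead caps the effective rate no matter how many samples there are. The function name and all numbers are hypothetical:

```python
def effective_bandwidth(n_batches, batch_mib, link_mib_s, fixed_s, per_batch_s):
    """Effective MiB/s under a hypothetical cost model:
    total_time = fixed_s + n_batches * (batch_mib / link_mib_s + per_batch_s)
    """
    total_mib = n_batches * batch_mib
    total_s = fixed_s + n_batches * (batch_mib / link_mib_s + per_batch_s)
    return total_mib / total_s


# Hypothetical parameters: ~1220.7 MiB batches over a 12000 MiB/s link,
# a 2 s one-off startup cost, and optionally 0.5 s of extra work per batch.
for n in (10, 50, 200, 10_000):
    fixed_only = effective_bandwidth(n, 1220.7, 12000.0, fixed_s=2.0, per_batch_s=0.0)
    with_per_batch = effective_bandwidth(n, 1220.7, 12000.0, fixed_s=2.0, per_batch_s=0.5)
    print(n, round(fixed_only), round(with_per_batch))
```

With per_batch_s=0.0 the rate climbs toward the 12000 MiB/s link rate as n grows; with per_batch_s=0.5 it levels off near 1220.7 / (1220.7 / 12000 + 0.5), roughly 2030 MiB/s, however many batches are added.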

I have yet to look under the hood, but I assume that when num_workers > 0 the results are passed through a multiprocessing queue. Does this queue retain data pinning?

If I check whether the returned tensor is pinned via x.is_pinned(), I get True in either case. So it seems strange that I cannot achieve more than the non-pinned transfer rate.

I verified that this is not a data-pinning issue, so it's OK to close this thread based on the title.

If the GPU is taken out of the DataLoader test entirely, the single-worker throughput is still about the same:

def DataLoader_Test(torch_loader, num_workers):
    torch.cuda.synchronize()
    t_start = time.time()
    for x in torch_loader:
        sample = x
        #sample = x.to('cuda', non_blocking=True)
    torch.cuda.synchronize()
    t_end = time.time()

    t_time = (t_end - t_start)*1000
    t_size = n_samples*(Batch*N*C*W)*4/1024/1024
    t_bw = t_size/(t_end-t_start)

    print('DataLoader Test: ', num_workers, ' workers')
    print('size of transfer ', t_size, 'MB')
    print('time taken by transfer ', t_time, 'mSec')
    print('Effective bandwidth ', t_bw, 'MBps (DataLoader:',num_workers,' workers)')

Results:

DataLoader Test: 0 workers
size of transfer 244140.625 MB
time taken by transfer 8440.999269485474 mSec
Effective bandwidth 28923.18992166928 MBps (DataLoader: 0 workers)
DataLoader Test: 1 workers
size of transfer 244140.625 MB
time taken by transfer 127765.62309265137 mSec
Effective bandwidth 1910.847527608873 MBps (DataLoader: 1 workers)

I looked at how the DataLoader is constructed: with workers, each batch is put on a worker result queue and, if pinning is enabled, is then passed through a second queue by a pin-memory thread. It must be the queue throughput and pinning overhead that are limiting performance.
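To get a feel for what the extra hop costs, here is a simplified, stdlib-only model of that two-queue path: a worker process puts payloads on a multiprocessing.Queue, a thread in the main process copies each one (standing in for pinning) onto a queue.Queue, and the main loop consumes them. This is a sketch under stated assumptions, not PyTorch's implementation. The real worker queue moves tensor storage into shared memory and sends only a handle, whereas this model pickles the whole payload through a pipe, and the real pin-memory thread copies into page-locked memory rather than a bytearray, so absolute numbers are not comparable. The function names are hypothetical, and the fork start method (POSIX only) is used to keep the example short:

```python
import multiprocessing as mp
import queue
import threading
import time

_SENTINEL = None  # end-of-stream marker (None survives pickling as None)


def worker(out_q, n_items, item_bytes):
    """Stand-in for a DataLoader worker: produce n_items payloads."""
    payload = bytes(item_bytes)
    for _ in range(n_items):
        out_q.put(payload)  # pickled and sent through a pipe
    out_q.put(_SENTINEL)


def pin_thread(in_q, out_q):
    """Stand-in for the pin-memory thread: copy each item, then forward it."""
    while True:
        item = in_q.get()
        if item is _SENTINEL:
            out_q.put(_SENTINEL)
            return
        out_q.put(bytearray(item))  # extra copy, approximating pinning


def run(n_items=20, item_bytes=8 * 1024 * 1024, second_hop=True):
    """Return (items received, MiB/s) through the modelled pipeline."""
    ctx = mp.get_context("fork")  # POSIX only; avoids re-importing __main__
    result_q = ctx.Queue()
    proc = ctx.Process(target=worker, args=(result_q, n_items, item_bytes))
    start = time.perf_counter()
    proc.start()
    data_q = result_q
    if second_hop:
        data_q = queue.Queue()
        threading.Thread(target=pin_thread, args=(result_q, data_q), daemon=True).start()
    received = 0
    while data_q.get() is not _SENTINEL:  # drain everything before joining
        received += 1
    elapsed = time.perf_counter() - start
    proc.join()
    return received, received * item_bytes / (1024 * 1024) / elapsed


if __name__ == "__main__":
    print("worker queue only:", run(second_hop=False))
    print("plus copy thread :", run(second_hop=True))
```

Comparing the two runs on a given machine shows how much the second hop and its copy add on top of the worker queue alone.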

If pin_memory is set to False, again bypassing the GPU:

num_workers=0
torch_loader = DataLoader(pytorch_dataset(w1),
                          batch_size=None,
                          num_workers=num_workers,
                          pin_memory=False)
DataLoader_Test(torch_loader, num_workers)

num_workers=1
torch_loader = DataLoader(pytorch_dataset(w1),
                          batch_size=None,
                          num_workers=num_workers,
                          pin_memory=False)

DataLoader_Test(torch_loader, num_workers)

The bandwidth is now much higher; with a single worker, the data passes through the worker queue only, not the second pinning queue:

DataLoader Test: 0 workers
size of transfer 244140.625 MB
time taken by transfer 5.549430847167969 mSec
Effective bandwidth 43993813.36999484 MBps (DataLoader: 0 workers)
DataLoader Test: 1 workers
size of transfer 244140.625 MB
time taken by transfer 14370.611906051636 mSec
Effective bandwidth 16988.881656263326 MBps (DataLoader: 1 workers)
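One caveat about the 0-worker, pin_memory=False figure: with batch_size=None and no workers, the loader returns the dataset's own tensors unchanged, so nothing is copied and the 43993813 MBps number reflects only loop overhead. A small CPU-only check (no GPU needed) makes the difference visible. It is a sketch assuming a recent PyTorch in which DataLoader accepts multiprocessing_context, it uses the fork start method (POSIX only), and the tensor sizes are arbitrary:

```python
import torch
from torch.utils.data import DataLoader

# Small CPU tensors; a plain list works as a map-style dataset.
samples = [torch.randn(4, 1024) for _ in range(3)]


def first_item(num_workers):
    """Fetch the first item the way the loops above do (batch_size=None)."""
    loader = DataLoader(
        samples,
        batch_size=None,
        num_workers=num_workers,
        pin_memory=False,
        # Only allowed with workers; "fork" is POSIX only.
        multiprocessing_context="fork" if num_workers > 0 else None,
    )
    return next(iter(loader))


no_workers = first_item(0)
one_worker = first_item(1)

# No workers: the loader hands back the dataset's own tensor, so no copy happens.
print(no_workers.data_ptr() == samples[0].data_ptr())   # True
# One worker: same values, but the tensor crossed the worker queue and now
# lives in different (shared) memory, i.e. it was copied on the way.
print(one_worker.data_ptr() == samples[0].data_ptr())   # False
print(torch.equal(one_worker, samples[0]))              # True
print(one_worker.is_shared())
```

With pin_memory=True the pin-memory thread would then copy that shared-memory tensor once more into page-locked memory, which is the second hop discussed above.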