Why does using 4 GPUs with DataParallel make inference slower than running without DataParallel?

I’m using ResNet-18 to run inference on a test set. On a single GPU it takes only 10s, but when I try to use multiple GPUs by adding only this code:

if torch.cuda.device_count() > 1:
  print("Let's use", torch.cuda.device_count(), "GPUs!")
  model = torch.nn.DataParallel(model)

the same run takes 24s. My full code is below:

import torch
import torch.nn.functional as F
from torchvision import datasets, transforms
import time
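device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
# The checkpoint holds the whole pickled model (not a state_dict);
# on PyTorch >= 2.6 torch.load may additionally need weights_only=False.
model = torch.load('resnet_18-cifar_10.pth')
model = model.to(device)
model.eval()  # use BatchNorm running statistics for inference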
#if torch.cuda.device_count() > 1:
#  print("Let's use", torch.cuda.device_count(), "GPUs!")
#  model = torch.nn.DataParallel(model)
lossSum = 0.0
preLoss = 0.0
curLoss = 0.0
accuracy = 0.0
correctNum = 0
data_transforms = transforms.Compose([
    transforms.ToTensor(),  # CIFAR10 yields PIL images; Normalize needs a tensor
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])
image_datasets = datasets.CIFAR10(root='../data', train=False,
                                  download=True, transform=data_transforms)
test_loader = torch.utils.data.DataLoader(image_datasets, batch_size=500,
                                          shuffle=True, num_workers=4, prefetch_factor=2)
since = time.time()
with torch.no_grad():  # inference only: skip building the autograd graph
    for data, target in test_loader:
        data = data.to(device)
        target = target.to(device)
        output = model(data)
        lossSum = lossSum + F.cross_entropy(output, target, reduction='sum').item()
        pred = output.argmax(1, keepdim=True)
        correctNum += pred.eq(target.view_as(pred)).sum().item()
curLoss = lossSum / len(test_loader.dataset)
accuracy = correctNum / len(test_loader.dataset)
time_elapsed = time.time() - since
print('Test complete in {:.0f}m {:.0f}s'.format(
    time_elapsed // 60, time_elapsed % 60))
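When comparing the single-GPU and DataParallel runs, it can help to time only the model's forward pass, separately from the DataLoader, and to exclude one-time warm-up costs. Below is a minimal sketch under those assumptions; `time_inference` and `_sync_all_gpus` are hypothetical helper names, not PyTorch APIs. Because CUDA kernels launch asynchronously, the sketch synchronizes every visible GPU before reading the clock.

```python
import time

import torch
import torch.nn as nn


def _sync_all_gpus() -> None:
    # CUDA kernels run asynchronously; wait for every visible GPU
    # (DataParallel uses several) before reading the clock.
    if torch.cuda.is_available():
        for i in range(torch.cuda.device_count()):
            torch.cuda.synchronize(i)


def time_inference(model: nn.Module, batch: torch.Tensor,
                   warmup: int = 3, iters: int = 10) -> float:
    """Return the mean wall-clock seconds per forward pass of `model` on `batch`."""
    model.eval()
    with torch.no_grad():
        for _ in range(warmup):  # exclude one-time costs such as CUDA/cuDNN initialization
            model(batch)
        _sync_all_gpus()
        start = time.perf_counter()
        for _ in range(iters):
            model(batch)
        _sync_all_gpus()
    return (time.perf_counter() - start) / iters
```

For a fair comparison, look at time per image: for example, time the plain model on a batch of 500 images on `cuda:0`, then `torch.nn.DataParallel(model)` on a batch of 2000 images (also on `cuda:0`), and divide each result by its batch size.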

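nn.DataParallel adds overhead to every iteration: it replicates the model parameters onto each device, scatters the input batch across them, and gathers the outputs (and, during training, the gradients) back onto the first device, as explained in this blog post. However, it should not slow down execution as long as you scale up the batch size. With batch_size=500 split across 4 GPUs, each replica processes only 125 images per step, so this per-step overhead can outweigh the compute saved by parallelizing.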
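To make "scaling up the batch size" concrete, the hypothetical helper below computes each GPU's share of one batch in plain Python. It assumes DataParallel splits the input along dimension 0 the way `torch.chunk` does (chunks of `ceil(batch_size / num_devices)` samples, with a smaller or missing last chunk). Treat it as an illustration of the arithmetic, not as PyTorch's actual scatter code.

```python
import math


def per_device_batch_sizes(batch_size: int, num_devices: int) -> list[int]:
    """Approximate how one batch is split across devices (torch.chunk-style)."""
    chunk = math.ceil(batch_size / num_devices)  # size of every chunk but the last
    sizes = []
    remaining = batch_size
    while remaining > 0:
        sizes.append(min(chunk, remaining))
        remaining -= chunk
    return sizes


print(per_device_batch_sizes(500, 4))   # [125, 125, 125, 125]
print(per_device_batch_sizes(2000, 4))  # [500, 500, 500, 500]
```

Under this assumption, keeping each GPU at the 500 images it handled in the single-GPU run means passing `batch_size=500 * torch.cuda.device_count()` to the DataLoader.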

We generally recommend using DistributedDataParallel with a single process per device for the best performance.
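Here is a minimal sketch of the one-process-per-device pattern applied to evaluation. Pure inference needs no gradient synchronization, so this sketch does not wrap the model in `DistributedDataParallel` (that wrapper matters for training). Instead, each process evaluates its own shard of the dataset, selected with `DistributedSampler`, and the loss and accuracy totals are combined with `dist.all_reduce`. The function names, the `127.0.0.1:29500` rendezvous address, and the CPU/`gloo` fallback are assumptions for illustration.

```python
import os

import torch
import torch.distributed as dist
import torch.multiprocessing as mp
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset, DistributedSampler


def evaluate_shard(rank: int, world_size: int, model: nn.Module,
                   dataset: Dataset, batch_size: int) -> list:
    """Evaluate this rank's shard; return [loss_sum, correct, count] summed over all ranks."""
    os.environ.setdefault("MASTER_ADDR", "127.0.0.1")  # assumed single-machine setup
    os.environ.setdefault("MASTER_PORT", "29500")      # assumed free port
    use_cuda = torch.cuda.is_available()
    if use_cuda:
        torch.cuda.set_device(rank)  # one GPU per process
    device = torch.device(f"cuda:{rank}" if use_cuda else "cpu")
    dist.init_process_group("nccl" if use_cuda else "gloo", rank=rank, world_size=world_size)
    try:
        model = model.to(device).eval()
        # Each rank gets a disjoint slice; the sampler pads with repeated samples
        # when len(dataset) is not divisible by world_size.
        sampler = DistributedSampler(dataset, num_replicas=world_size, rank=rank, shuffle=False)
        loader = DataLoader(dataset, batch_size=batch_size, sampler=sampler)
        stats = torch.zeros(3, device=device)  # loss sum, correct predictions, samples seen
        with torch.no_grad():
            for data, target in loader:
                data, target = data.to(device), target.to(device)
                output = model(data)
                stats[0] += F.cross_entropy(output, target, reduction="sum")
                stats[1] += (output.argmax(1) == target).sum()
                stats[2] += target.numel()
        dist.all_reduce(stats, op=dist.ReduceOp.SUM)  # combine totals from every rank
        return stats.cpu().tolist()
    finally:
        dist.destroy_process_group()


def _worker(rank: int, world_size: int, model: nn.Module,
            dataset: Dataset, batch_size: int) -> None:
    loss_sum, correct, count = evaluate_shard(rank, world_size, model, dataset, batch_size)
    if rank == 0:
        print(f"loss {loss_sum / count:.4f}  accuracy {correct / count:.4f}")


def launch(model: nn.Module, dataset: Dataset, batch_size: int) -> None:
    """Start one process per GPU (or 2 CPU processes when no GPU is available)."""
    world_size = torch.cuda.device_count() if torch.cuda.is_available() else 2
    mp.spawn(_worker, args=(world_size, model, dataset, batch_size), nprocs=world_size)
```

Because `mp.spawn` starts fresh interpreter processes, call `launch(model, image_datasets, 500)` from inside an `if __name__ == "__main__":` guard. Here each GPU keeps its full batch of 500 images, and no parameters are copied between devices during the loop.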