nn.DataParallel doesn't speed up forward

Hi everyone, I’m using nn.DataParallel for multi-GPU training, but strangely I don’t get any speedup on either the forward or the backward pass. Is there anything I’m missing when using DataParallel?
My code for profiling the forward pass is as follows:

import torch
import torch.nn as nn
import time

DIM = 128
class Generator(nn.Module):
    def __init__(self):
        super(Generator, self).__init__()
        self.preprocess = nn.Sequential(
            nn.Linear(128, 4 * 4 * 4 * DIM),
            nn.BatchNorm1d(4 * 4 * 4 * DIM),
            nn.ReLU(True),
        )

        self.main_module = nn.Sequential(
            nn.ConvTranspose2d(
                4 * DIM, 2 * DIM, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(2 * DIM),
            nn.ReLU(True),
            nn.ConvTranspose2d(2 * DIM, DIM, kernel_size=4,
                               stride=2, padding=1),
            nn.BatchNorm2d(DIM),
            nn.ReLU(True),
            nn.ConvTranspose2d(DIM, 3, kernel_size=4, stride=2, padding=1),
            nn.Tanh(),
        )

    def forward(self, input):
        output = self.preprocess(input)
        output = output.view(-1, 4 * DIM, 4, 4)
        output = self.main_module(output)
        return output.view(-1, 3, 32, 32)


class Discriminator(nn.Module):
    def __init__(self):
        super(Discriminator, self).__init__()
        self.main_module = nn.Sequential(
            nn.Conv2d(3, DIM, kernel_size=4, stride=2, padding=1),
            nn.LeakyReLU(),
            # 16x16
            nn.Conv2d(DIM, 2 * DIM, kernel_size=4, stride=2, padding=1),
            nn.LeakyReLU(),
            # 8x8
            nn.Conv2d(2 * DIM, 4 * DIM, kernel_size=4, stride=2, padding=1),
            nn.LeakyReLU(),
            # 4 x 4
        )
        self.linear = nn.Linear(4 * 4 * 4 * DIM, 1)

    def forward(self, input):
        output = self.main_module(input)
        output = output.view(-1, 4 * 4 * 4 * DIM)
        output = self.linear(output)
        return output


if __name__ == "__main__":
    torch.backends.cudnn.benchmark = True
    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
    print(device)
    print(torch.__version__)
    batch_size = 256
    print('====Single GPU test====')
    D = Discriminator().to(device)
    G = Generator().to(device)
    data = (torch.rand((batch_size, 3, 32, 32), device=device) - 0.5) / 0.5
    z = torch.randn((batch_size, 128), device=device)
    for i in range(2):
        torch.cuda.synchronize()
        start = time.time()
        loss = D(data) - D(G(z))
        torch.cuda.synchronize()
        end = time.time()
        if i != 0:  # skip reporting for the first iteration because of cudnn.benchmark
            print('Iter: %d; Forward time cost: %.6fs' % (i, end - start))

    print('====Two GPUs test====')
    D2 = Discriminator().to(device)
    G2 = Generator().to(device)
    D2 = nn.DataParallel(D2, list(range(2)))
    G2 = nn.DataParallel(G2, list(range(2)))
    for i in range(2):
        torch.cuda.synchronize()
        start = time.time()
        loss = D2(data) - D2(G2(z))  # time the DataParallel-wrapped models, not D and G
        torch.cuda.synchronize()
        end = time.time()
        if i != 0: # skip reporting for the first iteration because of cudnn.benchmark
            print('Iter: %d; Forward time cost: %.6fs' % (i, end - start))

Output is

cuda:0
1.7.0+cu101
====Single GPU test====
Iter: 1; Forward time cost: 0.013395s
====Two GPUs test====
Iter: 1; Forward time cost: 0.012617s

DataParallel doesn’t seem to speed up the forward pass.

Hey @Hongkai_Zheng, it’s possible that DataParallel's overhead outweighs its benefits here. On every iteration, DataParallel's forward function replicates the model, scatters the input, runs parallel_apply, and gathers the outputs. In addition, because DataParallel runs its replicas in multiple threads, those threads have to compete for the Python GIL. To check whether this is the case, try a larger batch_size so that the per-GPU forward computation becomes more expensive, and see whether the speedup becomes more visible. If tuning batch_size is not a good option for you, I would suggest using DistributedDataParallel. See this overview: PyTorch Distributed Overview (PyTorch Tutorials 1.7.1 documentation)
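To make the four per-iteration steps concrete, here is a rough sketch assembled from the public helpers in torch.nn.parallel. The function name data_parallel_forward is made up for illustration, and the real DataParallel.forward does more than this (keyword arguments, device checks), so treat it as an approximation of where the overhead comes from rather than the actual implementation.

```python
import torch
import torch.nn as nn
from torch.nn.parallel import replicate, scatter, parallel_apply, gather


def data_parallel_forward(module, inputs, device_ids, output_device=None):
    """Approximate the work nn.DataParallel repeats on every forward call.

    Assumes `module` and `inputs` already live on device_ids[0].
    With an empty device_ids list it simply runs the module as-is.
    """
    if not device_ids:
        return module(inputs)
    if output_device is None:
        output_device = device_ids[0]

    # 1. Split the batch along dim 0, one chunk per device.
    chunks = scatter(inputs, device_ids)
    # 2. Copy the module's parameters and buffers to each device in use.
    replicas = replicate(module, device_ids[:len(chunks)])
    # 3. Run every replica on its chunk, one Python thread per replica.
    outputs = parallel_apply(replicas, [(chunk,) for chunk in chunks])
    # 4. Concatenate the per-device outputs on the output device.
    return gather(outputs, output_device)
```

Steps 1, 2 and 4 are pure overhead relative to a single-GPU forward pass, which is why a small per-GPU workload may not show any speedup.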


Hi Shen, thanks for your explanation. Indeed, the benefits outweigh the overhead once the batch size reaches 512, though that is not a practical option for me.
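For anyone who wants to repeat this kind of check, a batch-size sweep could look roughly like the sketch below. The helper names time_forward and sweep_batch_sizes, the warm-up count, and the iteration count are all illustrative choices, not anything taken from the code above; averaging over several iterations is just one way to reduce noise in a single measurement.

```python
import time

import torch
import torch.nn as nn


def time_forward(model, batch, n_warmup=3, n_iters=10):
    """Return the average forward time in seconds, excluding warm-up runs."""
    on_cuda = batch.device.type == "cuda"
    for _ in range(n_warmup):
        model(batch)
    if on_cuda:
        torch.cuda.synchronize()  # wait for queued kernels before timing
    start = time.perf_counter()
    for _ in range(n_iters):
        model(batch)
    if on_cuda:
        torch.cuda.synchronize()  # make sure all timed kernels have finished
    return (time.perf_counter() - start) / n_iters


def sweep_batch_sizes(model, batch_sizes, in_features, device):
    """Time the forward pass of `model` for each batch size."""
    results = {}
    for batch_size in batch_sizes:
        batch = torch.randn(batch_size, in_features, device=device)
        results[batch_size] = time_forward(model, batch)
    return results
```

Running the sweep once on the plain model and once on the nn.DataParallel-wrapped model, then comparing the two dictionaries, should show the batch size at which the parallel version starts to win on a given machine.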

Regarding PyTorch DDP, my code currently relies on autograd.grad() to compute Hessian-vector products. But according to the DDP documentation, DDP doesn’t work with autograd.grad(). Is there any way to compute a Hessian-vector product so that I can use PyTorch distributed parallel training? It seems to me that nn.DataParallel is the only parallel technique I can use so far. I’m also wondering how much difference there is between DDP and DataParallel in terms of training speedup.

cc autograd expert @albanD for the autograd and Hessian-vector product questions.

I’m also wondering how much difference is there between DDP and DataParallel regarding training speedup.

This depends on factors such as model size, number of GPUs, GPU interconnect, etc. DDP avoids replicating the model in every iteration and avoids GIL contention. These overheads may weigh differently in different applications.

Hi,

You can use a function like the following to do the same thing as autograd.grad but with backward (requires a nightly build as of writing).
It will be more expensive than a vanilla autograd.grad, but the cost should be fairly similar.

from torch import autograd


def my_autograd_grad(outputs, inputs, grad_outputs, create_graph=False):
    # Save existing .grad values (if any) and clear them, because
    # backward() accumulates into .grad instead of overwriting it
    grads = []
    for i in inputs:
        grads.append(i.grad)
        i.grad = None

    # Do the backward, accumulating only into the given inputs
    autograd.backward(outputs, grad_outputs, inputs=inputs, create_graph=create_graph)

    # Get the result
    res_grad = tuple(i.grad for i in inputs)

    # Restore previous gradients
    for i, g in zip(inputs, grads):
        i.grad = g

    return res_grad
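As a sketch of how the save/backward/restore pattern above could serve the Hessian-vector product use case, the snippet below chains two backward passes: the first keeps the graph so the gradient is differentiable, and the second differentiates the dot product of that gradient with a vector. The names grad_via_backward and hessian_vector_product are illustrative, and the inputs= argument of autograd.backward is assumed to be available in the installed PyTorch version. PyTorch may emit a warning about reference cycles when backward is called with create_graph=True.

```python
import torch
from torch import autograd


def grad_via_backward(outputs, inputs, grad_outputs=None, create_graph=False):
    """Compute gradients with backward() while leaving existing .grad intact."""
    saved = [p.grad for p in inputs]
    for p in inputs:
        p.grad = None  # backward() accumulates, so start from an empty .grad
    autograd.backward(outputs, grad_outputs, inputs=list(inputs),
                      create_graph=create_graph)
    result = tuple(p.grad for p in inputs)
    for p, g in zip(inputs, saved):
        p.grad = g  # put back whatever was stored before the call
    return result


def hessian_vector_product(loss, params, vec):
    """Return H @ v for a scalar loss, using two backward passes."""
    # First pass: gradients that still carry a graph (create_graph=True).
    grads = grad_via_backward(loss, params, create_graph=True)
    # Dot product of the gradient with v; its gradient w.r.t. params is H @ v.
    dot = sum((g * v).sum() for g, v in zip(grads, vec))
    # Second pass: differentiate the dot product.
    return grad_via_backward(dot, params)


# Example: loss = sum(x^3) has Hessian diag(6 * x), so H @ v = 6 * x * v.
x = torch.tensor([1.0, 2.0, 3.0], requires_grad=True)
loss = (x ** 3).sum()
(hv,) = hessian_vector_product(loss, [x], [torch.ones(3)])
```

Whether this interacts correctly with DDP's gradient synchronization in a given training loop is not something this sketch verifies; it only shows the autograd side on a single process.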

Cool! Minor question: why do we need to delete i.grad before backward() and restore it afterwards? Is this where the “more expensive” part comes from?

Oh, is that because we do a backward() to compute the gradient before the Hessian-vector product, and i.grad is kept so that DDP can gather gradients across the devices?

That is because backward accumulates into the .grad field. But you might already need those values for other reasons, so you have to save and restore them.
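A tiny illustration of that accumulation behavior, using made-up tensors: calling backward twice without clearing .grad adds the second gradient to the first instead of replacing it.

```python
import torch

w = torch.ones(3, requires_grad=True)

# First backward: d/dw sum(2 * w) = 2 for every element.
(w * 2).sum().backward()
after_first = w.grad.clone()   # tensor([2., 2., 2.])

# Second backward without clearing: 3 is added on top of the existing 2.
(w * 3).sum().backward()
after_second = w.grad.clone()  # tensor([5., 5., 5.])

# Clearing .grad first gives just the new gradient.
w.grad = None
(w * 3).sum().backward()
after_reset = w.grad.clone()   # tensor([3., 3., 3.])
```

This is why the helper clears .grad before its backward call and restores the saved values afterwards.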
