Issue with using DataParallel (includes minimal code)

I have the following (minimal) code that runs on a single GPU, and I’m trying to run it on multiple GPUs using nn.DataParallel:

import math
import torch
import pickle
import time

import numpy as np
import torch.optim as optim

from torch import nn

print('device_count()', torch.cuda.device_count())
for i in range(torch.cuda.device_count()):
    print('get_device_name', torch.cuda.get_device_name(i))

def _data(dimension, num_examples):
    num_mislabeled_examples = 20

    ground_truth_weights = np.random.normal(size=dimension) / math.sqrt(dimension)
    ground_truth_threshold = 0

    features = np.random.normal(size=(num_examples, dimension)).astype(
        np.float32) / math.sqrt(dimension)
    labels = (np.matmul(features, ground_truth_weights) >
              ground_truth_threshold).astype(np.float32)
    mislabeled_indices = np.random.choice(
        num_examples, num_mislabeled_examples, replace=False)
    labels[mislabeled_indices] = 1 - labels[mislabeled_indices]

    return torch.tensor(labels), torch.tensor(features)

class tools:
    def __init__(self):
        self.name = 'x_2'

    def SomeFunc(self, model, input_):
        print(model.first_term(input_)[0])        


class predictor(nn.Module):
    def __init__(self, dim):
        super(predictor, self).__init__()
        self.weights = torch.nn.Parameter(torch.zeros(dim, 1, requires_grad=True))
        self.threshold = torch.nn.Parameter(torch.zeros(1, 1, requires_grad=True))

    def first_term(self, features):
        return features @ self.weights

    def forward(self, features):
        return self.first_term(features) - self.threshold

class HingeLoss(nn.Module):

    def __init__(self):
        super(HingeLoss, self).__init__()
        self.relu = nn.ReLU()

    def forward(self, output, target):
        all_ones = torch.ones_like(target)
        labels = 2 * target - all_ones
        losses = all_ones - torch.mul(output.squeeze(1), labels)

        return torch.norm(self.relu(losses))

class function(object):
    def __init__(self, epochs):

        dim = 10
        N = 100
        self.target, self.features = _data(dim, N)

        self.epochs = epochs 
        self.model = predictor(dim).to('cuda')
        self.optimizer = optim.SGD(self.model.parameters(), lr=1e-3)
        self.target = self.target.to('cuda')
        self.features = self.features.to('cuda')
        self.loss_function = HingeLoss().to('cuda')

        self.tools = tools()

    def train(self):

        self.model.train()
        for epoch in range(self.epochs):
            self.optimizer.zero_grad()
            output = self.model(self.features)
            # self.tools.SomeFunc(self.model, self.features)
            print(output.is_cuda)
            loss = self.loss_function(output, self.target)
            loss.backward()
            print('For epoch {}, loss is: {}.'.format(epoch, loss.item()))
            self.optimizer.step()

def main():
    model = function(1000)
    print(torch.cuda.device_count())
    if False: # This is Flag
        if torch.cuda.device_count() > 1:
            model.model = nn.DataParallel(model.model)
    t = time.time()
    model.train()
    print('elapsed: {}'.format(time.time() - t))

if __name__ == '__main__':
    main()

1. I have 4 GPU cards (device_count() = 4). When I set the flag (indicated by the comment This is Flag) to True, the code takes 15.78 seconds to run; when I set it to False, it takes 0.71 seconds. Why is that, and how can it be fixed?

2. When I uncomment the line self.tools.SomeFunc(self.model, self.features) and set the flag to True, I receive the following error:

AttributeError: 'DataParallel' object has no attribute 'first_term'

How should I fix this? Thanks!

One thing to note: you will need a torch.cuda.synchronize() before calling time.time() to make sure all pending CUDA kernels in the stream have finished.

You can also use elapsed_time on CUDA events to measure GPU time. See the discussion here.
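
For instance, applied to the main() above, a minimal sketch of the wall-clock timing with synchronization would be:

torch.cuda.synchronize()   # wait for any queued kernels before starting the clock
t = time.time()
model.train()
torch.cuda.synchronize()   # wait for the training kernels to finish
print('elapsed: {}'.format(time.time() - t))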

If you are looking for the most performant solution, DistributedDataParallel should be the way to go. [example]

When I uncomment the line self.tools.SomeFunc(self.model, self.features) and set the flag to True, I receive the following error:

It looks like self.model is a DataParallel instance. If so, DataParallel itself does not have the first_term attribute. If that attribute is on the model instance you passed to DataParallel, you can access the original model instance through self.model.module (see the DataParallel code here), which should have the first_term attribute.
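
For example, with the sample code above, SomeFunc could unwrap the DataParallel roughly like this (a sketch using the names from the question):

def SomeFunc(self, model, input_):
    # reach the original predictor instance when model is wrapped in DataParallel
    module = model.module if isinstance(model, nn.DataParallel) else model
    print(module.first_term(input_)[0])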

Regarding the first part, calling torch.cuda.synchronize() did not help. It still seems that using 4 GPUs instead of 1 makes the code 15 times slower.

After fixing the second part, my output for

print(model.module.first_term(input_)[0])

is always on the first worker:

tensor([0.0020], device='cuda:0', grad_fn=<SelectBackward>)

So it is not even sharing the work between the different workers.

What is “worker” here? Do you mean GPU? If so, isn’t that expected? IIUC, in your code, model is the DataParallel instance. So only the forward function of model would utilize multiple GPUs. See the data parallel implementation below:
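
Roughly, DataParallel.forward does the following (paraphrased from torch/nn/parallel/data_parallel.py; details vary across versions):

def forward(self, *inputs, **kwargs):
    if not self.device_ids:
        return self.module(*inputs, **kwargs)
    # scatter inputs across devices, replicate the module, run one thread
    # per replica, then gather the outputs back onto the output device
    inputs, kwargs = self.scatter(inputs, kwargs, self.device_ids)
    if len(self.device_ids) == 1:
        return self.module(*inputs[0], **kwargs[0])
    replicas = self.replicate(self.module, self.device_ids[:len(inputs)])
    outputs = self.parallel_apply(replicas, inputs, kwargs)
    return self.gather(outputs, self.output_device)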


Regarding the slowdown, I can reproduce it locally with two GPUs:

Using DataParallel on 2 GPUs

For epoch 999, loss is: 8.4603910446167.
elapsed: 3.1627652645111084

Not using DataParallel

For epoch 999, loss is: 8.323615074157715.
elapsed: 1.192000389099121

Then I went back to inspect the model:

    def __init__(self, dim):
        super(predictor, self).__init__()
        self.weights = torch.nn.Parameter(torch.zeros(dim, 1, requires_grad=True))
        self.threshold = torch.nn.Parameter(torch.zeros(1, 1, requires_grad=True))

My suspicion is that the model parameters are so light that the overhead of GIL contention, input scattering, output gathering, and model replication in the DataParallel forward pass overshadows the speedup brought by multi-GPU training. Is this parameter size representative of your real application, or are you trying to profile DataParallel performance?


I didn’t quite get your point on this one.

That makes sense. Let me try it on a more sophisticated piece of code and update the answer.

I’m just making sure I understand how to use DataParallel correctly. Although much simpler, this sample code mimics the general structure of my actual code. So there is nothing wrong with my implementation?

Thanks. Based on your response, I used this wrapper to access the attributes without altering the rest of the code.
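
A sketch of that kind of wrapper (the one I linked may differ in details):

class MyDataParallel(nn.DataParallel):
    def __getattr__(self, name):
        # fall back to the wrapped module for attributes that
        # DataParallel itself does not define (e.g. first_term)
        try:
            return super().__getattr__(name)
        except AttributeError:
            return getattr(self.module, name)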

I might have misunderstood the original question.

After fixing the second part, my output for

print(model.module.first_term(input_)[0])

is always on the first worker:

tensor([0.0020], device='cuda:0', grad_fn=<SelectBackward>)

Is the question about why the output of print(model.module.first_term(input_)[0]) is always on cuda:0?


Yes: I’m under the impression that there are two ways of parallelizing PyTorch code: DistributedDataParallel and DataParallel. In the former, each layer of the network is assigned to a particular processor, while in the latter, each processor takes a portion of the training data and all processors go through all of the code (like here). Although DistributedDataParallel is preferred (though I’m not sure why, except perhaps for multi-node processing), it looks hairy, so I decided to start with DataParallel. Hence, I expected all processors to call first_term() when they get to that part of the code. What am I missing?

This is mostly for performance reasons. As of today, DataParallel (DP) replicates the model in its forward pass, while DistributedDataParallel (DDP) replicates the model in its constructor. That means DP re-replicates the model in every iteration. Besides that, DP also suffers from GIL contention, as it is single-process multi-thread. DDP does not hit this problem, as each model replica runs in its own process. More info about DDP can be found here and here.
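
For reference, a minimal single-node DDP sketch for the predictor above could look like the following (illustrative only; it reuses predictor and optim from the sample code, and the backend, port, and data sharding are assumptions):

import os
import torch
import torch.distributed as dist
import torch.multiprocessing as mp
from torch.nn.parallel import DistributedDataParallel as DDP

def run(rank, world_size):
    os.environ['MASTER_ADDR'] = 'localhost'
    os.environ['MASTER_PORT'] = '29500'
    dist.init_process_group('nccl', rank=rank, world_size=world_size)

    # one process per GPU; the model is replicated once, in the DDP constructor
    ddp_model = DDP(predictor(10).to(rank), device_ids=[rank])
    optimizer = optim.SGD(ddp_model.parameters(), lr=1e-3)
    # ... each process then trains on its own shard of the data ...

    dist.destroy_process_group()

if __name__ == '__main__':
    world_size = torch.cuda.device_count()
    mp.spawn(run, args=(world_size,), nprocs=world_size)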

Hence, I expected all processors to call first_term() when they get to that part of the code. What am I missing?

What happens in DP’s forward function is: 1) replicate the model to all devices, 2) scatter inputs to all devices, 3) launch multiple threads in parallel, where each thread processes an input split using one model replica on one device, and 4) gather outputs to the same device.

Given the above, if you change the predictor code to the following, you will see it print multiple devices.

class predictor(nn.Module):
    ...

    def forward(self, features):
        print(self.first_term(features).device)
        return self.first_term(features) - self.threshold

However, for the following code:

    def SomeFunc(self, model, input_):
        print(model.first_term(input_)[0])

If it is called outside of the forward pass, or if the model argument is not a model replica (i.e., not the self argument in the predictor.forward method), then it won’t show different devices.
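
For instance (illustrative, using the names from the sample code):

# Called outside of DataParallel.forward, this runs on whichever device holds
# the original module's parameters (cuda:0 here), no matter how many GPUs exist.
print(model.module.first_term(input_)[0].device)   # -> cuda:0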


So I tried another piece of code: with 1 GPU it ran in 434 seconds, while with 2 GPUs it took 864 seconds. So the slowdown shouldn’t just be the fixed overhead we pay for parallelization. Also, using your line print(self.first_term(features).device), I can see that all GPUs are used at each step, so the work is not being run serially on each GPU.

Can we profile how much of the 434 seconds is spent in the forward pass when DP is not present, and how much of that is spent on the GPU? This can be measured using elapsed_time. See this discussion.

Note that multi-threading cannot parallelize normal Python ops due to the Python GIL; the parallelism only kicks in when execution does not require the GIL (e.g., CPU/GPU ops that explicitly release the GIL).


I’m not sure if I’m doing this correctly. Is the measurement below right for the sample code?

def main():
    model = function(1000)
    print(torch.cuda.device_count())
    if True:
        if torch.cuda.device_count() > 1:
            model.model = MyDataParallel(model.model)

    start = time.monotonic()
    s = torch.cuda.current_stream()
    # CUDA events with enable_timing=True measure time on the current stream
    e_start = torch.cuda.Event(enable_timing=True)
    e_finish = torch.cuda.Event(enable_timing=True)
    s.record_event(e_start)

    model.train()

    torch.cuda.synchronize()
    s.record_event(e_finish)
    e_finish.synchronize()
    end = time.monotonic()

    # elapsed_time returns milliseconds
    print('Forward latency is: {} '.format(e_start.elapsed_time(e_finish)))
    print("end - start = ", end - start)

if __name__ == '__main__':
    main()

This looks correct to me. Does MyDataParallel use DataParallel internally?


Yes, it’s just a wrapper so that I can access the attributes without adding .module. Below is the output for my actual code:

Forward latency is: 437033.9375
end - start = 437.031393879035

So what do you think is wrong with the parallel implementation, given that it takes twice as much time to run with two GPUs?

This is surprising to me. I would assume that at least the CUDA ops could run in parallel. Could you please share the full code used in this test? We will investigate.


Sure, I can share it privately. Could you please send me your GitHub ID?

Thanks, it’s mrshenli.


Done, let me know if you received it.

Got it and will investigate. Thanks for sharing.


Hello @mrshenli, I was wondering whether you have any thoughts yet on why this happens? Thanks.