DataParallel doesn't work when calling model.module.some_attribute



In the two snippets below there are two models, model and model_p, both wrapped in nn.DataParallel. But when I call the fit attribute on model via model.module, I'm unable to utilize the two GPUs I wanted to parallelize the model across, i.e. model doesn't split the batch-first dim=0 dimension into two equal halves and put them onto the two devices, as can be seen from the print statements.
PS: I am very new to DataParallel and wanted to use something like this, i.e. what I actually want is to call model.module.fit in my training loop with the inputs from my dataloader as arguments, and this fit method ultimately makes a call to the forward method of the Model class.

But this doesn't parallelize across the two GPUs, whereas model_p, which has no fit function and calls forward directly, does.

I've added a link to the notebook, which was run with CUDA_VISIBLE_DEVICES=0,1.

What should I change?
Thanks!
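
(The snippets assume a setup roughly like the one in the DataParallel tutorial; something like the following sketch defines the names used below, with the exact sizes being placeholders:)

import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader

input_size = 5
output_size = 2
batch_size = 30
data_size = 100

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

class RandomDataset(Dataset):
    # Random tensors, just to give the model some batches to consume
    def __init__(self, size, length):
        self.len = length
        self.data = torch.randn(length, size)

    def __getitem__(self, index):
        return self.data[index]

    def __len__(self):
        return self.len

rand_loader = DataLoader(dataset=RandomDataset(input_size, data_size),
                         batch_size=batch_size, shuffle=True)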

class Model(nn.Module):
    # Our model


    def __init__(self, input_size, output_size):
        super(Model, self).__init__()
        self.fc = nn.Linear(input_size, output_size)

    def forward(self, input):
        output = self.fc(input)
        return output
    
    def fit(self, input):
        output = self.forward(input)
        print("\tIn Model: input size", input.size(),
              "output size", output.size())
        return output

model = Model(input_size, output_size)
if torch.cuda.device_count() > 1:
    print("Let's use", torch.cuda.device_count(), "GPUs!")
    # dim = 0: [30, xxx] -> [15, ...], [15, ...] on 2 GPUs
    model = nn.DataParallel(model)

model.to(device)

for data in rand_loader:
    input = data.to(device)
    output = model.module.fit(input)
    print("Outside: input size", input.size(),"output_size", output.size())

#############################CASE 2############################
class ModelParallel(nn.Module):
    # Our model

    def __init__(self, input_size, output_size):
        super(ModelParallel, self).__init__()
        self.fc = nn.Linear(input_size, output_size)

    def forward(self, input):
        output = self.fc(input)
        print("\tIn Model: input size", input.size(),
              "output size", output.size())

        return output
    
model_p = ModelParallel(input_size, output_size)
if torch.cuda.device_count() > 1:
    print("Let's use", torch.cuda.device_count(), "GPUs!")
    # dim = 0: [30, xxx] -> [15, ...], [15, ...] on 2 GPUs
    model_p = nn.DataParallel(model_p)
#   model.module.fit = nn.DataParallel(model.module.fit)

model_p.to(device)

for data in rand_loader:
    input = data.to(device)
    output = model_p(input)
    print("Outside: input size", input.size(),"output_size", output.size())

DataParallel splits the batch across GPUs inside its own forward function; it is implemented as a wrapper around your module rather than as a subclass that overrides the module's forward. When you call fit, you're calling the forward() of the underlying model, not the one defined by DataParallel, so the scatter/gather logic in DataParallel.forward(...) is never executed and only a single GPU is used.

From the source of DataParallel.forward:

    def forward(self, *inputs, **kwargs):
        if not self.device_ids:
            return self.module(*inputs, **kwargs)

        for t in chain(self.module.parameters(), self.module.buffers()):
            if t.device != self.src_device_obj:
                raise RuntimeError("module must have its parameters and buffers "
                                   "on device {} (device_ids[0]) but found one of "
                                   "them on device: {}".format(self.src_device_obj, t.device))

        inputs, kwargs = self.scatter(inputs, kwargs, self.device_ids)
        if len(self.device_ids) == 1:
            return self.module(*inputs[0], **kwargs[0])
        replicas = self.replicate(self.module, self.device_ids[:len(inputs)])
        outputs = self.parallel_apply(replicas, inputs, kwargs)
        return self.gather(outputs, self.output_device)
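
In other words, anything you want split across GPUs has to go through the DataParallel wrapper's own __call__/forward. A minimal sketch of the difference in the training loop (the train_step helper name is just for illustration):

# model is the nn.DataParallel-wrapped instance from Case 1 above

def train_step(wrapped_model, input):
    # Calling the wrapper triggers DataParallel.forward, i.e. the
    # scatter -> replicate -> parallel_apply -> gather path shown above,
    # so the batch is split along dim 0 across the visible GPUs.
    output = wrapped_model(input)
    print("\ttrain_step: input size", input.size(),
          "output size", output.size())
    return output

for data in rand_loader:
    input = data.to(device)
    output = train_step(model, input)      # parallelized
    # output = model.module.fit(input)     # bypasses the wrapper -> single GPU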

Thanks for the answer @jerinphilip! That makes sense and explains why it only uses one GPU in the case of fit.
So, is there anything I can do to parallelize the fit method? Or is the only way to parallelize a model to call forward through the DataParallel-wrapped model itself?

You can use a flag keyword argument inside forward, since the two functions don't differ by much. I tried switching member functions with a flag, and the following segment worked for me:

In my case I'm switching between the generator, discriminator, and critic in a GAN-Actor-Critic setup. I'm using a tag argument here to control which sub-model's forward is called.
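
Roughly, the pattern is something like the sketch below; the sub-module definitions are just placeholders for the actual generator/discriminator/critic:

import torch.nn as nn

class GANActorCritic(nn.Module):
    def __init__(self, generator, discriminator, critic):
        super().__init__()
        self.generator = generator
        self.discriminator = discriminator
        self.critic = critic

    def forward(self, input, tag="generator"):
        # `tag` picks which sub-model's forward to run. Since this is the
        # wrapped module's forward, DataParallel still scatters `input`
        # along dim 0 and replicates `tag` to each worker.
        if tag == "generator":
            return self.generator(input)
        if tag == "discriminator":
            return self.discriminator(input)
        if tag == "critic":
            return self.critic(input)
        raise ValueError("unknown tag: {}".format(tag))

# wrapped = nn.DataParallel(GANActorCritic(G, D, C)).to(device)
# fake = wrapped(noise, tag="generator")       # every call goes through forward,
# score = wrapped(fake, tag="discriminator")   # so every call is parallelized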

You can look at scatter's source code to understand how args and kwargs are split and replicated across workers, in case there's any confusion (which I had at the time).
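
For a quick feel of the splitting behaviour, the functional torch.nn.parallel.scatter helper (which DataParallel.forward uses under the hood via scatter_kwargs) chunks tensor arguments along dim 0 and copies everything else to each replica. Assuming two visible GPUs:

import torch
from torch.nn.parallel import scatter

batch = torch.randn(30, 5)
chunks = scatter(batch, [0, 1])   # split along dim 0 onto cuda:0 and cuda:1
for chunk in chunks:
    print(chunk.size(), chunk.device)
    # -> torch.Size([15, 5]) on cuda:0, then cuda:1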

Thanks a tonne for helping me out! :smile:
I found a simple way to change my code and use the parallel functionality in my forward.
Everything is working as expected now.

Hi, I ran into the same problem and want to use multiple GPUs even when calling model.module.predict, where predict is a method defined in the model class. So, could you tell me the simple way you found, @gollum? Thanks!
