Pass multiple parameters into DistributedDataParallel

What is the right way to use DistributedDataParallel when we have several things to optimize?
I have a model that I am wrapping in DistributedDataParallel, but I also have another class with some parameters to optimize during training. Do I simply wrap it in DistributedDataParallel as well, so that the final script looks like:

if distributed:
    model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[device])
    other_class = torch.nn.parallel.DistributedDataParallel(other_class, device_ids=[device])

DistributedDataParallel expects an nn.Module as its first argument, so you might need to create a custom module containing your parameters if you want to use them in DDP.
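
For example, a minimal sketch of such a wrapper module (the class, attribute, and argument names here are made up for illustration):

import torch
import torch.nn as nn

class LossWeights(nn.Module):  # hypothetical module holding the extra parameters
    def __init__(self, num_tasks):
        super().__init__()
        # nn.Parameter registers the tensor so that DDP and the optimizer
        # can find it through .parameters()
        self.weights = nn.Parameter(torch.ones(num_tasks))

    def forward(self, task_losses):
        # e.g. scale each per-task loss by its learnable weight and sum
        return (self.weights * task_losses).sum()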

Thanks for your reply!
Yes, that makes sense - I have indeed created two classes that inherit from nn.Module:

One for the model

class MyModel(nn.Module):
    def __init__(self):
        super().__init__()
        # some code here

and one for multi-task losses

class Weightloss(nn.Module):
    def __init__(self):
        super().__init__()
        # some code here

    def forward(self, x):
        # some code here

So if I then pass them both into two separate torch.nn.parallel.DistributedDataParallel wrappers, like below:

model = MyModel()
w_loss = Weightloss()
if distributed:
    model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[device])
    w_loss = torch.nn.parallel.DistributedDataParallel(w_loss, device_ids=[device])

Is this the right way of doing it, and will everything be properly synchronized across all the GPUs? Or should these classes somehow be passed to a single torch.nn.parallel.DistributedDataParallel? Something like:

if distributed:
    model, w_loss = torch.nn.parallel.DistributedDataParallel(model, w_loss, device_ids=[device])

(which I guess will throw an error if I initialize it this way)

Yes, I think you are right, as DDP expects a single module as its input argument.

Probably the safest approach would be to initialize Weightloss inside your MyModel, which would allow you to pass only the MyModel object to DDP.
If that’s not possible, you could try to use two separate DDP objects, which I would assume should also work, as they would be executed sequentially.
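
A rough sketch of the first suggestion, assuming Weightloss simply holds learnable weights applied to the model output (the Linear backbone, learning rate, and optimizer choice are placeholders; device and distributed are set up as in the snippets above):

import torch
import torch.nn as nn

class MyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.backbone = nn.Linear(10, 2)  # placeholder for the real model
        # registering Weightloss as a submodule makes its parameters part of
        # MyModel.parameters()
        self.w_loss = Weightloss()

    def forward(self, x):
        out = self.backbone(x)
        # applying the weighting inside forward keeps the Weightloss
        # parameters in the graph that DDP synchronizes
        return self.w_loss(out)

model = MyModel().to(device)
if distributed:
    # a single DDP wrapper now covers both the backbone weights and the
    # Weightloss parameters
    model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[device])

# one optimizer sees both sets of parameters
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)

With the two-DDP-object variant, each wrapper synchronizes the gradients of its own module during backward, and you would still pass the parameters of both wrapped modules to your optimizer(s).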
