Multiple modules with distributed data parallel

Hi there,

I am trying to use DistributedDataParallel for multi-GPU training with multiple nn.Modules. My programme consists of two nn.Modules: 1. a model (containing the model parameters) and 2. a constraint (containing Lagrangian parameters). Both depend on the same backward pass but have their own optimisers, something along the lines of:

model_optimiser = optimiser(model.parameters(), ...)
constraint_optimiser = optimiser(constraint.parameters(), ...)

predictions = model(inputs)
loss_1, loss_2 = calculate_loss(predictions)
total_loss = loss_1 + constraint(loss_2)

total_loss.backward()  # one backward pass for both modules
model_optimiser.step()
constraint_optimiser.step()
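To make this structure concrete, here is a minimal single-process sketch. The toy `nn.Linear` model, the hypothetical `Constraint` module with one learnable multiplier, and this two-term `calculate_loss` (which also takes targets) are assumptions, not the original code. It shows one backward pass filling gradients for both modules, after which each optimiser updates only its own parameters.

```python
import torch
import torch.nn as nn


class Constraint(nn.Module):
    """Hypothetical constraint: scales a loss term by a learnable multiplier."""

    def __init__(self):
        super().__init__()
        self.multiplier = nn.Parameter(torch.tensor(1.0))

    def forward(self, loss_term):
        return self.multiplier * loss_term


def calculate_loss(predictions, targets):
    # Hypothetical split into two scalar (0-dimensional) loss terms.
    loss_1 = ((predictions - targets) ** 2).mean()
    loss_2 = predictions.abs().mean()
    return loss_1, loss_2


torch.manual_seed(0)
model = nn.Linear(4, 1)
constraint = Constraint()
model_optimiser = torch.optim.SGD(model.parameters(), lr=0.1)
constraint_optimiser = torch.optim.SGD(constraint.parameters(), lr=0.1)

inputs, targets = torch.randn(8, 4), torch.randn(8, 1)

model_optimiser.zero_grad()
constraint_optimiser.zero_grad()
predictions = model(inputs)
loss_1, loss_2 = calculate_loss(predictions, targets)
total_loss = loss_1 + constraint(loss_2)
total_loss.backward()  # one backward pass fills gradients for both modules
model_optimiser.step()  # updates only the model parameters
constraint_optimiser.step()  # updates only the multiplier
```

Whether the multiplier should be updated by descent or by ascent depends on the constrained formulation; the sketch uses plain SGD descent for both only to keep the focus on the shared backward pass.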

This seems to run fine on one GPU without DistributedDataParallel, but when I move to multiple GPUs with DistributedDataParallel it raises an error.

I found this answer and tried what is suggested there, something along the lines of:

import os
import torch
import torch.distributed as dist
import torch.multiprocessing as mp

# in main
world_size = int(config.n_gpus * config.n_nodes)
mp.spawn(train, nprocs=world_size, args=(config,))


# in train(device_rank, config), the function run by each spawned process
world_size = int(config.n_gpus * config.n_nodes)
os.environ['MASTER_ADDR'] = get_ip()
os.environ['MASTER_PORT'] = str(port_nr)

dist.init_process_group(backend='nccl', init_method='env://', world_size=world_size, rank=device_rank)

pg_model = torch.distributed.new_group(list(range(world_size)))
model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[device_rank], find_unused_parameters=True, process_group=pg_model)

pg_constraint = torch.distributed.new_group(list(range(world_size)))
constraint = torch.nn.parallel.DistributedDataParallel(constraint, device_ids=[device_rank], find_unused_parameters=True, process_group=pg_constraint)

The error I get is:

File "/home/cbarkhof/code-thesis/NewsVAE/", line 401, in assemble_loss
    beta_kl = self.manager["beta_KL"]["constraint"](kl_loss)
  File "/home/cbarkhof/.local/lib/python3.6/site-packages/torch/nn/modules/", line 727, in _call_impl
    result = self.forward(*input, **kwargs)
  File "/home/cbarkhof/.local/lib/python3.6/site-packages/torch/nn/parallel/", line 617, in forward
    inputs, kwargs = self.scatter(inputs, kwargs, self.device_ids)
  File "/home/cbarkhof/.local/lib/python3.6/site-packages/torch/nn/parallel/", line 643, in scatter
    return scatter_kwargs(inputs, kwargs, device_ids, dim=self.dim)
  File "/home/cbarkhof/.local/lib/python3.6/site-packages/torch/nn/parallel/", line 36, in scatter_kwargs
    inputs = scatter(inputs, target_gpus, dim) if inputs else []
  File "/home/cbarkhof/.local/lib/python3.6/site-packages/torch/nn/parallel/", line 28, in scatter
    res = scatter_map(inputs)
  File "/home/cbarkhof/.local/lib/python3.6/site-packages/torch/nn/parallel/", line 15, in scatter_map
    return list(zip(*map(scatter_map, obj)))
  File "/home/cbarkhof/.local/lib/python3.6/site-packages/torch/nn/parallel/", line 13, in scatter_map
    return Scatter.apply(target_gpus, None, dim, obj)
  File "/home/cbarkhof/.local/lib/python3.6/site-packages/torch/nn/parallel/", line 92, in forward
    outputs = comm.scatter(input, target_gpus, chunk_sizes, ctx.dim, streams)
  File "/home/cbarkhof/.local/lib/python3.6/site-packages/torch/nn/parallel/", line 186, in scatter
    return tuple(torch._C._scatter(tensor, devices, chunk_sizes, dim, streams))
RuntimeError: chunk expects at least a 1-dimensional tensor

The error is raised at the call constraint(loss_2).

loss_2 is indeed zero-dimensional, but that is expected, as it is just an average loss term. I guess something is wrong in my set-up. Can anyone point me in the right direction? That would help me a lot, thanks!
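The failure can be reproduced without any distributed setup. In the traceback, `DistributedDataParallel.forward` calls `scatter`, which splits every input tensor into chunks along `dim` (0 by default), and a 0-dimensional tensor has no dimension to split. This sketch, assuming a recent PyTorch on CPU, calls `torch.chunk` directly rather than going through DDP, and shows that `unsqueeze(0)` adds the missing dimension.

```python
import torch

loss_2 = torch.tensor(0.25)  # a 0-dimensional average loss, like loss_2 in the question

error_message = ""
try:
    torch.chunk(loss_2, 1, dim=0)
except RuntimeError as err:
    error_message = str(err)
    print(error_message)  # matches the message at the end of the traceback

# Adding a leading dimension gives chunk something to split.
chunks = torch.chunk(loss_2.unsqueeze(0), 1, dim=0)
print(chunks[0].shape)  # torch.Size([1])
```

This only illustrates the shape requirement; the reply in this thread avoids routing the scalar through a separate DDP wrapper altogether.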



Hey @ClaartjeBarkhof, can you try wrapping your model and constraint into one nn.Module first, and then wrapping that module with DistributedDataParallel? Something like:

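import torch.nn as nn
from torch.nn.parallel import DistributedDataParallel

class WrapperModule(nn.Module):
  def __init__(self, model, constraint):
    super().__init__()  # required before assigning submodules
    self.model = model
    self.constraint = constraint

  def forward(self, inputs):
    predictions = self.model(inputs)
    loss_1, loss_2 = calculate_loss(predictions)
    # loss_2 is computed inside this module, so DDP never receives the 0-dim tensor as an input
    total_loss = loss_1 + self.constraint(loss_2)
    return total_loss

ddp = DistributedDataParallel(
  WrapperModule(model, constraint),
  device_ids=[device_rank],
)
# model and constraint still hold the same parameter objects, so both optimisers work as before
model_optimiser = optimiser(model.parameters(), ...)
constraint_optimiser = optimiser(constraint.parameters(), ...)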
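A runnable sketch of the same idea, assuming a CPU-only, single-process `gloo` group (`world_size=1`, port 29511 chosen arbitrarily), a toy `nn.Linear` model, a hypothetical `Constraint` module, and a hypothetical two-term loss computed inline. Because `loss_2` is produced inside `WrapperModule.forward`, the only tensors `DistributedDataParallel` receives as inputs are the batch tensors, and both optimisers still step after a single backward pass.

```python
import os

import torch
import torch.distributed as dist
import torch.nn as nn
from torch.nn.parallel import DistributedDataParallel


class Constraint(nn.Module):
    """Hypothetical constraint: one learnable multiplier times a loss term."""

    def __init__(self):
        super().__init__()
        self.multiplier = nn.Parameter(torch.ones(1))

    def forward(self, loss_term):
        return self.multiplier[0] * loss_term  # result stays 0-dimensional


class WrapperModule(nn.Module):
    """Runs the model and the constraint inside one forward pass."""

    def __init__(self, model, constraint):
        super().__init__()
        self.model = model
        self.constraint = constraint

    def forward(self, inputs, targets):
        predictions = self.model(inputs)
        loss_1 = ((predictions - targets) ** 2).mean()
        loss_2 = predictions.abs().mean()  # 0-dim, computed inside the wrapped forward
        return loss_1 + self.constraint(loss_2)


# Assumed single-process CPU setup (gloo, world_size=1) so the sketch runs without GPUs.
os.environ["MASTER_ADDR"] = "127.0.0.1"
os.environ["MASTER_PORT"] = "29511"
dist.init_process_group(backend="gloo", init_method="env://", rank=0, world_size=1)

torch.manual_seed(0)
model, constraint = nn.Linear(4, 1), Constraint()
ddp = DistributedDataParallel(WrapperModule(model, constraint))  # CPU module: no device_ids

# DDP wraps these same parameter objects, so each optimiser still owns its subset.
model_optimiser = torch.optim.SGD(model.parameters(), lr=0.1)
constraint_optimiser = torch.optim.SGD(constraint.parameters(), lr=0.1)

inputs, targets = torch.randn(8, 4), torch.randn(8, 1)
model_optimiser.zero_grad()
constraint_optimiser.zero_grad()
total_loss = ddp(inputs, targets)
total_loss.backward()  # DDP all-reduces gradients of both submodules here
model_optimiser.step()
constraint_optimiser.step()

dist.destroy_process_group()
```

With more processes, the gradients of both submodules are averaged across ranks during `backward()`, so each optimiser applies the same update on every rank.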


BTW, I noticed that the code uses two different process groups for the model and the constraint. With the NCCL backend, this could lead to a deadlock: as of NCCL 2.8, only one communicator may be used on each CUDA device at any given time (i.e., it is not thread-safe).


Thanks for your swift reply @mrshenli, I am going to try this! Will follow up.

Your solution works, thank you @mrshenli.
