How to pass multiple inputs to forward() in such a way that DistributedDataParallel still hooks up to all of them?

I have a model whose forward() takes multiple tensors as inputs. The model is wrapped in DDP, and to make that work I put all of my input tensors into a single dictionary and pass that to forward(). I know, very clever.
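
To make this concrete, here is a sketch of the setup I mean (the name Simulator, the layer sizes, and the dict keys are made up, and I assume the process group is already initialized):

    import torch.nn as nn
    from torch.nn.parallel import DistributedDataParallel as DDP

    class Simulator(nn.Module):
        def __init__(self):
            super().__init__()
            self.enc_a = nn.Linear(16, 32)   # consumes the first input tensor
            self.enc_b = nn.Linear(8, 32)    # consumes the second input tensor
            self.head = nn.Linear(32, 1)

        def forward(self, inputs):
            # inputs is a dict of tensors, e.g. {"a": ..., "b": ...}
            h = self.enc_a(inputs["a"]) + self.enc_b(inputs["b"])
            return self.head(h)

    # after dist.init_process_group(...) on every rank:
    # simulator = DDP(Simulator().to(device), device_ids=[device.index])
    # out = simulator({"a": batch_a, "b": batch_b})

The dict is just a convenience for bundling the inputs; as far as I understand, DDP passes whatever arguments it receives straight through to the wrapped module's forward().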

My problem is that DDP only syncs the first weight matrix of the model, so this approach is probably completely wrong.

How do I make DDP work with multiple input tensors?

EDIT: I tried removing the dictionary entirely and passing the tensors to forward() as separate arguments; again, only the first parameter is synced.

EDIT 1: As a temporary workaround I am using:

    # torch.distributed is imported as dist and the process group is initialized
    for param in simulator.parameters():
        dist.all_reduce(param.data, op=dist.ReduceOp.SUM)  # sum each parameter over all ranks
        param.data /= dist.get_world_size()                # divide by world size to get the mean

But this is obviously not ideal, because for one thing all_reduce now runs twice per step: once by DDP on the gradients and once by me on the parameters.

I’m not sure how the input arguments to the forward method are related to the parameter sync issue, but DDP should synchronize the entire state_dict at the beginning.
Are all parameters properly registered and returned by model.state_dict() before you wrap it into DDP?
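
Something along these lines (a sketch, run on the plain model before wrapping) should show whether everything is registered:

    # list every trainable parameter the module knows about
    for name, param in model.named_parameters():
        print(name, tuple(param.shape), param.requires_grad)
    # state_dict() additionally contains buffers (e.g. BatchNorm running stats)
    print(list(model.state_dict().keys()))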


So the input arguments to the forward method do not really make a difference; that was a wrong assumption on my part.

But my problem remains: DDP only syncs one weight tensor, even though model.state_dict() lists the same tensors before and after wrapping in DDP, so I do not understand why the others are not synced. Is there a way to check which tensors have a DDP all_reduce hook attached to them?
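
The closest thing to a check I could come up with is comparing each parameter against rank 0's copy after a training step (a sketch, assuming the process group is initialized; check_sync is just a name I made up):

    import torch
    import torch.distributed as dist

    def check_sync(model):
        # any parameter that differs from rank 0's copy apparently never got reduced
        for name, param in model.named_parameters():
            ref = param.detach().clone()
            dist.broadcast(ref, src=0)   # every rank receives rank 0's version
            if not torch.allclose(param.detach(), ref):
                print(f"rank {dist.get_rank()}: {name} is out of sync")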

Could you share a minimal, executable code snippet showing this behavior, please?