How to pass multiple inputs to forward() in such a way that DistributedDataParallel still hooks up to all of them?

I have a model whose forward() takes multiple tensors as inputs. The model is wrapped in DDP, and to make that work I packed all of the input tensors into a single dictionary that I pass to forward(). I know, very clever.

My problem is that DDP only syncs the first weight matrix of the model, so this approach is probably completely wrong.

How do I make DDP work with multiple input tensors?

EDIT: I tried removing the dictionary entirely and passing the tensors to forward() as separate arguments; again, only the first parameter is synced.
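
For reference, this is roughly the shape of what I am doing; the module, layer sizes, and launch details below are made up for illustration (launched via torchrun), not my actual code:

    import os
    import torch
    import torch.distributed as dist
    import torch.nn as nn
    from torch.nn.parallel import DistributedDataParallel as DDP

    # illustrative model with two separate tensor inputs to forward()
    class Simulator(nn.Module):
        def __init__(self):
            super().__init__()
            self.fc_a = nn.Linear(16, 32)
            self.fc_b = nn.Linear(8, 32)

        def forward(self, x_a, x_b):
            # both inputs feed the output, so gradients flow to both branches
            return self.fc_a(x_a) + self.fc_b(x_b)

    def main():
        dist.init_process_group("nccl")  # env vars set by torchrun
        rank = int(os.environ["LOCAL_RANK"])
        torch.cuda.set_device(rank)

        model = Simulator().to(rank)
        ddp_model = DDP(model, device_ids=[rank])

        # multiple positional tensors are passed straight through to forward()
        x_a = torch.randn(4, 16, device=rank)
        x_b = torch.randn(4, 8, device=rank)
        ddp_model(x_a, x_b).sum().backward()

    if __name__ == "__main__":
        main()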

EDIT 1: As a temporary workaround I am using

    import torch.distributed as dist

    # manually average every parameter across all ranks after the optimizer step
    for param in simulator.parameters():
        dist.all_reduce(param.data, op=dist.ReduceOp.SUM)
        param.data /= dist.get_world_size()

But this is obviously not ideal because, for one, all_reduce is now called twice: once by DDP and once by me.

I'm not sure how the input arguments to the forward method are related to the parameter sync issue, but DDP should synchronize the entire state_dict at the beginning.
Are all parameters properly registered and returned by model.state_dict() before you wrap it into DDP?
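
For example, you could print the registered parameters of the plain (unwrapped) module and check that nothing is missing; `model` here is just a placeholder name for your module:

    # list everything that is registered on the unwrapped module
    for name, p in model.named_parameters():
        print(name, tuple(p.shape), p.requires_grad)
    print(list(model.state_dict().keys()))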

So the input arguments to the forward method do not really make a difference; that was a wrong assumption on my part.

But my problem remains: DDP only syncs one weight tensor, even though model.state_dict() lists the same tensors before and after wrapping in DDP, so I do not understand why the others are not synced. Is there a way to check which tensors have a DDP all_reduce hook attached to them?
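
In the meantime I am checking the sync by hand after each optimizer step with something like the rough helper below (the name and structure are just mine, not anything from DDP):

    import torch
    import torch.distributed as dist

    # rough cross-rank check: broadcast rank 0's copy of each parameter
    # and compare it with the local copy
    def check_params_in_sync(model):
        for name, p in model.named_parameters():
            ref = p.detach().clone()
            dist.broadcast(ref, src=0)
            if not torch.allclose(p.detach(), ref):
                print(f"rank {dist.get_rank()}: {name} is out of sync")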

Could you share a minimal, executable code snippet showing this behavior, please?