I have a model whose forward() takes multiple tensors as inputs. The model is wrapped in DDP, and to make that work I created an input dictionary containing all of my input tensors. I know, very clever.
My problem is that DDP only syncs the first weight matrix of the model, so this approach is probably completely wrong.
How do I make DDP work with multiple input tensors?
EDIT: I tried removing the dictionary entirely and passing the tensors to forward() as separate arguments, and again only the first parameter is synced.
EDIT 1: As a temporary workaround I am using

    # Manually average every parameter across ranks (dist is torch.distributed)
    for param in simulator.parameters():
        dist.all_reduce(param.data, op=dist.ReduceOp.SUM)
        param.data /= dist.get_world_size()
But this is obviously not ideal because, for one, all_reduce is now called twice: once by DDP and once by me.
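
To make the setup concrete, here is a minimal sketch of a DDP-wrapped module whose forward() takes two tensors, plus a check of how far each rank's parameters drift from rank 0's after a few optimizer steps. Everything named here is hypothetical and not my actual code: TwoInputNet stands in for the real model, max_param_divergence and run_demo are illustrative helpers, and the sketch assumes the CPU "gloo" backend with a file-based rendezvous. It defaults to a single process so it runs anywhere; for a real test it would need to be launched once per rank with world_size > 1.

```python
import os
import tempfile

import torch
import torch.distributed as dist
import torch.nn as nn
from torch.nn.parallel import DistributedDataParallel as DDP


class TwoInputNet(nn.Module):
    """Hypothetical stand-in for the real model: forward() takes two tensors."""

    def __init__(self):
        super().__init__()
        self.fc_a = nn.Linear(4, 8)
        self.fc_b = nn.Linear(3, 8)
        self.head = nn.Linear(8, 1)

    def forward(self, a, b):
        return self.head(torch.relu(self.fc_a(a) + self.fc_b(b)))


def max_param_divergence(module):
    """Largest absolute difference between this rank's parameters and rank 0's."""
    worst = 0.0
    for p in module.parameters():
        ref = p.detach().clone()
        dist.broadcast(ref, src=0)  # every rank receives rank 0's copy
        worst = max(worst, (p.detach() - ref).abs().max().item())
    return worst


def run_demo(rank=0, world_size=1, init_file=None):
    # Assumption: CPU training with gloo and a shared file for rendezvous.
    if init_file is None:
        init_file = os.path.join(tempfile.mkdtemp(), "ddp_init")
    dist.init_process_group(
        "gloo", init_method=f"file://{init_file}", rank=rank, world_size=world_size
    )
    torch.manual_seed(rank)  # different data per rank
    model = DDP(TwoInputNet())
    opt = torch.optim.SGD(model.parameters(), lr=0.1)
    for _ in range(3):
        a, b = torch.randn(5, 4), torch.randn(5, 3)
        # Call the DDP wrapper itself (not model.module) with both tensors.
        loss = model(a, b).pow(2).mean()
        opt.zero_grad()
        loss.backward()  # DDP all-reduces gradients during backward
        opt.step()
    divergence = max_param_divergence(model.module)
    dist.destroy_process_group()
    return divergence
```

Running this with several ranks (all pointing at the same init_file) and printing the returned value per rank would show whether the parameters actually diverge in a minimal case, which should help narrow down whether the problem is in how forward() is called or somewhere else in my training loop.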