Is it legitimate to optimize only the module part of DDP?

Taking this code snippet as an example:

import torch.optim as optim
from torch.nn.parallel import DistributedDataParallel as DDP

model = Model()
ddp_model = DDP(model)

optimizer = optim.Adam(model.parameters())  # instead of ddp_model.parameters()

output = ddp_model(...)     # forward call
loss_fn(output).backward()  # backward

optimizer.step()  # This runs on `model.parameters()` (instead of `ddp_model.parameters()`). Is it problematic?

I tried this setup in distributed training (i.e. DDP is used only for initialization and the forward call, while everything else accesses model directly) and haven’t hit any issues so far. However, it doesn’t follow the routine shown in the official tutorial.

I read somewhere that model.forward() should never be called after ddp_model is constructed (I can’t find the reference). That makes sense, since DDP’s forward call takes care of certain ops needed for synchronization in the backward step. But what about the other parts? E.g., can I call optimizer.step() as above, as well as model.train() and model.zero_grad()? I know that model.state_dict() is still valid and that id(model) == id(ddp_model.module), so there are probably many operations that remain valid. Hence the question: where’s the boundary? Is there anything else “forbidden” like model.forward()?

Use case: I usually wrap the optimizer, as well as many other helper/utility functions, as class members of the Model itself for ease of model version control. So I’m wondering what restrictions apply to these helper members (e.g. initializing an optimizer member with optim.Adam(self.parameters(), ...)) in the context of DDP.
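To illustrate the pattern (a hypothetical Model class that owns its optimizer; the single-process gloo group exists only to make the sketch runnable), note that DDP wraps the module rather than copying it, so an optimizer built from self.parameters() points at exactly the tensors DDP trains:

```python
import os
import torch
import torch.nn as nn
import torch.optim as optim
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP

# Hypothetical pattern from the question: the model owns its optimizer.
class Model(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(4, 2)
        # Built from self.parameters(), i.e. before any DDP wrapping.
        self.opt = optim.Adam(self.parameters(), lr=1e-3)

    def forward(self, x):
        return self.fc(x)

# Single-process "distributed" setup, just so the sketch runs standalone.
os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
os.environ.setdefault("MASTER_PORT", "29500")
dist.init_process_group("gloo", rank=0, world_size=1)

model = Model()
ddp_model = DDP(model)

# DDP wraps, it does not copy: the wrapped module is the same object,
# so model.parameters() and ddp_model.parameters() yield the same tensors.
assert ddp_model.module is model
assert all(a is b for a, b in zip(model.parameters(), ddp_model.parameters()))

dist.destroy_process_group()
```

Since the parameter tensors are shared, stepping the member optimizer updates the same storage that DDP synchronizes; the open question in this thread is only whether DDP features beyond plain gradient averaging rely on the optimizer being constructed differently.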

Generally, I would of course stick to the example usage, as it’s used and tested in other code bases. I don’t have an example showing breaking behavior, but I also don’t know how all the bells and whistles of DDP would interact with your approach. In particular, I don’t know whether native mixed-precision training, fused optimizers, ddp_comm_hooks, etc. could break.


Got it. Thank you for your reply! Then is it fair to say that, other than the “mutable” operations, including:

  1. ddp_model.train()/.eval()
  2. ddp_model.forward() (along with the subsequent backward from the output of the forward call + some loss function)
  3. optimizer.step() where optimizer is constructed from ddp_model.parameters()

the other operations, especially the ones that essentially treat the original model as a const object, are still safe to use without breaking DDP’s training framework? E.g.:

  1. model.state_dict() (or ddp_model.module.state_dict(), this should be the most common one)
  2. some “const” member function of the Model class, e.g. a custom evaluation method called after ddp_model.eval() and within a torch.no_grad() context?
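A minimal sketch of these two “const” uses (a plain nn.Linear and a single-process gloo group stand in for a real model and cluster): the DDP state_dict prefixes every key with module., while the wrapped module’s state_dict keeps the original names, and read-only evaluation under eval()/no_grad() runs through the DDP forward without issue:

```python
import os
import torch
import torch.nn as nn
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP

# Single-process "distributed" setup, just so the sketch runs standalone.
os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
os.environ.setdefault("MASTER_PORT", "29501")
dist.init_process_group("gloo", rank=0, world_size=1)

model = nn.Linear(4, 2)
ddp_model = DDP(model)

# Case 1: state_dict. The DDP wrapper adds a "module." prefix to every key;
# ddp_model.module.state_dict() (== model.state_dict()) does not.
assert set(ddp_model.state_dict()) == {"module.weight", "module.bias"}
assert set(ddp_model.module.state_dict()) == {"weight", "bias"}

# Case 2: a read-only evaluation pass in eval mode under no_grad,
# which never touches the gradient-synchronization machinery.
ddp_model.eval()
with torch.no_grad():
    out = ddp_model(torch.randn(8, 4))
assert out.shape == (8, 2)

dist.destroy_process_group()
```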

These statements appear true per my understanding of the DDP framework, but I’d like more clarity, in case something breaks unintentionally. Custom Model class methods are quite common, and I need to be clear about what can and cannot be called under the DDP framework.

There are still some caveats:

  • You should never call the .forward method of any model manually; instead, call the model itself (via model(input)), which invokes the __call__ method internally. This makes sure all hooks are properly fired before the forward is executed. Calling model.forward(input) could easily break things (standalone and DDP).
  • The state_dict of the DDP vs. standalone model will differ, as the former prefixes every key with module.. It’s generally good practice to save ddp_model.module.state_dict().
  • model.eval(), torch.no_grad() etc. will work.
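The first caveat can be demonstrated without DDP at all: forward hooks are dispatched by nn.Module.__call__, so a hook registered on a plain module fires through model(input) but is silently skipped by model.forward(input):

```python
import torch
import torch.nn as nn

model = nn.Linear(3, 3)

calls = []
# Forward hooks run inside nn.Module.__call__, not inside forward() itself.
model.register_forward_hook(lambda mod, inp, out: calls.append("fired"))

x = torch.randn(2, 3)
model(x)          # goes through __call__ -> the hook fires
model.forward(x)  # bypasses __call__ -> the hook is silently skipped
assert calls == ["fired"]
```

DDP (like mixed-precision autocast and other features) relies on this same hook machinery, which is why model(input) is the only supported way to run the forward pass.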

Understood. Thank you!