I’m looking at the train() function in the DDP source code (reference):
    def train(self, mode=True):
        super().train(mode)
        if self._use_replicated_tensor_module:
            self._replicated_tensor_module.train(mode)  # type: ignore[union-attr]
        return self
What is the _replicated_tensor_module for?
Context: I’m trying to understand the difference between ddp_model.train() and ddp_model.module.train().
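To make the comparison concrete, here is a minimal sketch of how I would check the two calls. It assumes that DDP keeps the wrapped model as a child attribute named module, so that the super().train(mode) call above recurses into it via nn.Module.train(). Wrapper is a hypothetical stand-in I made up so the snippet runs without a process group. It is not the real DistributedDataParallel class, and it does not cover the _replicated_tensor_module branch.

```python
import torch.nn as nn


class Wrapper(nn.Module):
    """Hypothetical stand-in for DDP: holds the wrapped model as the child `module`."""

    def __init__(self, module: nn.Module):
        super().__init__()
        self.module = module


model = nn.Sequential(nn.Linear(4, 4), nn.Dropout(0.5))
wrapped = Wrapper(model)

# Calling eval()/train() on the wrapper recurses into every child module.
wrapped.eval()
flags_after_wrapper_eval = (wrapped.training, model.training, model[1].training)

# Calling train() on the inner module leaves the wrapper's own flag untouched.
wrapped.module.train()
flags_after_inner_train = (wrapped.training, model.training, model[1].training)

print(flags_after_wrapper_eval)
print(flags_after_inner_train)
```

If that assumption holds, the only difference I can see from this sketch is the wrapper's own training flag, which is why I am asking what the extra _replicated_tensor_module call adds.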