DDP Before or After Optim?

I’ve been trying to convert my trainer to use DDP when available. I’m wondering whether there is anything wrong with the approach below. It seems to work for me, and I’ve checked it as thoroughly as I can.

The issue I’m having is that the tutorial advises setting up DDP before the optimiser, and using the DDP model for loading/saving, but I am finding that setting up DDP after these steps works better given that I want to do EMA checkpointing. My implementation is based on OpenAI’s PyTorch code.

This is covered elsewhere, but DDP will prepend module. to your parameter names. This means that if my EMA state is created and saved using the DDP model, it will not reload correctly when switching to non-DDP training. However, if I use the unwrapped model for loading (with a parameter sync across ranks), saving, and EMA setup, and the DDP model only for the forward-backward pass, then everything seems to work. This goes against the advice in the tutorials, although they don’t suggest it is explicitly wrong.
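To illustrate the key mismatch with a toy model (the names here are illustrative, not from my trainer): a state dict saved from the DDP object carries a module. prefix that a plain model won’t accept, but the prefix can be stripped:

```python
import torch.nn as nn

# Toy stand-in for the real network (illustrative only).
model = nn.Linear(4, 2)

# DDP stores the wrapped network under the attribute `module`, so a
# state dict saved from the DDP object prefixes every key with "module.".
ddp_state = {f"module.{k}": v for k, v in model.state_dict().items()}

# Stripping the prefix makes the checkpoint loadable into the plain model again.
plain_state = {k.removeprefix("module."): v for k, v in ddp_state.items()}
model.load_state_dict(plain_state)
```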

As far as I can see, so long as the DDP-wrapped model is used for the forward-backward passes, the unwrapped model will have the same weights across all ranks at the start and after weight updates. I’ve checked this using…

import torch
import torch.distributed as dist

def assert_equal(tensor: torch.Tensor) -> None:
    """Assert that `tensor` is identical across all ranks (rank 0 collects)."""
    with torch.no_grad():
        gather_list = [torch.zeros_like(tensor) for _ in range(dist.get_world_size())]
        # Only the destination rank supplies a gather_list; other ranks pass None.
        dist.gather(tensor, gather_list if dist.get_rank() == 0 else None, dst=0)

    if dist.get_rank() != 0:
        dist.barrier()
        return

    unequal_indices = [i for i, t in enumerate(gather_list) if not torch.equal(tensor, t)]
    assert not unequal_indices, f"Tensors at ranks {unequal_indices} are not equal to rank 0."
    dist.barrier()
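Part of why this holds, as far as I can tell, is that DDP keeps a reference to the original module (as .module) rather than copying its parameters, so the wrapped and unwrapped views share the very same tensors. No process group is needed to see the idea; a sketch with a stand-in wrapper (not DDP itself) behaves the same way:

```python
import torch
import torch.nn as nn

model = nn.Linear(4, 2)
wrapper = nn.Sequential(model)  # stand-in: DDP likewise holds the module, it doesn't copy it

# The wrapper exposes the very same Parameter objects...
assert wrapper[0].weight is model.weight

# ...so an optimiser step taken through either view updates both.
opt = torch.optim.SGD(wrapper.parameters(), lr=0.1)
wrapper(torch.randn(1, 4)).sum().backward()
opt.step()
assert torch.equal(wrapper[0].weight, model.weight)
```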

Here is a snippet from my trainer’s initialisation, which appears to be working…

        self._load_params()

        decay, no_decay = get_weight_decay_parameters(self.model)
        optim_groups = [
            {"params": decay, "weight_decay": self.weight_decay},
            {"params": no_decay, "weight_decay": 0.0},
        ]

        self.opt = torch.optim.AdamW(optim_groups, lr=self.learning_rate)
        self._load_optim_state()

        self.ema_named_params = [self._setup_ema_named_params(rate) for rate in self.ema_rate]

        if dist.is_initialized():
            # convert_sync_batchnorm reuses the existing parameter tensors,
            # so the converted model's weights stay shared with self.model
            model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(self.model)
            self.ddp_model = DDP(
                model,
                device_ids=[dist_util.device()],
                output_device=dist_util.device(),
                find_unused_parameters=False,
            )
        else:
            self.ddp_model = self.model

        # checks for debugging, will be removed
        for p in self.model.parameters():
            dist_util.assert_equal(p)

        for p in self.ddp_model.parameters():
            dist_util.assert_equal(p)

As mentioned above, loading/saving and EMA creation are applied to the unwrapped model, and the DDP model is only used for the forward-backward pass. I’ve checked the weights after several iterations and they do not appear to diverge across the ranks.
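For reference, the EMA update I apply to the unwrapped model’s parameters looks roughly like this (a sketch in the spirit of the OpenAI code; the name update_ema and its signature are mine, not the exact implementation):

```python
import torch

def update_ema(ema_params, model_params, rate=0.9999):
    """In-place EMA update: ema <- rate * ema + (1 - rate) * current."""
    with torch.no_grad():
        for ema_p, p in zip(ema_params, model_params):
            ema_p.mul_(rate).add_(p, alpha=1 - rate)
```

Because the EMA tensors are keyed off the unwrapped model’s (unprefixed) parameter names, the saved EMA checkpoints load cleanly whether or not the next run uses DDP.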

I suspect this is a valid method, but the tutorials advise using the DDP model for everything, presumably because that is less confusing than using the DDP-wrapped model for some parts and the unwrapped one for others.

Any advice would be really appreciated.