This is probably trivial, but I saw in this example for the ZeRO optimizer that after wrapping a model with DDP, peak GPU memory almost doubles.
What is the reason? I thought DDP's only job was to synchronize gradients efficiently during the backward pass, so I'm not sure what work it needs to do during initialization. I also suspect it does something during forward() that I'm not aware of; is that the case?
As mentioned on that page, each worker keeps a separate copy of the optimizer state, so that overhead adds to the extra memory usage.
The idea behind ZeroRedundancyOptimizer comes from the DeepSpeed/ZeRO project and Marian, which shard optimizer states across distributed data-parallel processes to reduce the per-process memory footprint.
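A back-of-the-envelope sketch of why sharding helps (the function name, parameter count, and world size here are my own illustration, not from the tutorial): Adam keeps two fp32 state tensors per parameter (exp_avg and exp_avg_sq), and ZeRO-style sharding splits that state roughly evenly across ranks.

```python
def adam_state_bytes(num_params: int, world_size: int, use_zero: bool) -> int:
    """Approximate per-rank Adam optimizer-state memory in bytes (fp32).

    Adam stores two fp32 tensors per parameter (exp_avg, exp_avg_sq),
    i.e. 2 * 4 bytes per parameter. With ZeRO-style sharding each rank
    holds roughly 1/world_size of that state.
    """
    full_state = 2 * 4 * num_params
    return full_state // world_size if use_zero else full_state

num_params = 100_000_000  # hypothetical 100M-parameter model
print(adam_state_bytes(num_params, world_size=4, use_zero=False))  # 800000000 (~800 MB per rank)
print(adam_state_bytes(num_params, world_size=4, use_zero=True))   # 200000000 (~200 MB per rank)
```

So with 4 workers, each rank's Adam state shrinks to about a quarter of the full-replication cost; the model parameters and gradients themselves are unaffected by this particular optimization.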
Thanks @smth, but as far as I can tell, the print that shows the extra memory being allocated happens before any optimizer is instantiated:
ddp_model = DDP(model, device_ids=[rank])
print_peak_memory("Max memory allocated after creating DDP", rank)
# define loss function and optimizer
loss_fn = nn.MSELoss()
if use_zero:
    optimizer = ZeroRedundancyOptimizer(
        ddp_model.parameters(),
        optimizer_class=torch.optim.Adam,
        lr=0.01
    )
else:
    optimizer = torch.optim.Adam(ddp_model.parameters(), lr=0.01)