Training Info:
GPU Device Type: A100
Number of GPUs: 8
Code snippet:
import torch.distributed as dist
from torch.distributed.algorithms.ddp_comm_hooks.powerSGD_hook import (
    PowerSGDState, batched_powerSGD_hook)

# Build a process group covering all ranks, then register the batched
# PowerSGD gradient-compression hook on the DDP-wrapped model.
process_group = dist.new_group(ranks=None, backend="nccl")
ddp_model.register_comm_hook(
    PowerSGDState(process_group=process_group, matrix_approximation_rank=1),
    batched_powerSGD_hook,
)
Reference: DDP Communication Hooks — PyTorch 1.11.0 documentation
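For completeness, here is a minimal, runnable sketch of the setup the snippet above assumes. It uses a single process with the gloo backend on CPU so it runs without GPUs; the linear model, address, and port are placeholders. On the 8x A100 setup you would instead launch 8 ranks (e.g. via torchrun) and use backend="nccl":

```python
import os
import torch.nn as nn
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.distributed.algorithms.ddp_comm_hooks.powerSGD_hook import (
    PowerSGDState, batched_powerSGD_hook)

# Single-process group purely for illustration; replace with your
# real launcher-provided rank/world_size and backend="nccl" on GPUs.
os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
os.environ.setdefault("MASTER_PORT", "29501")
dist.init_process_group(backend="gloo", rank=0, world_size=1)

ddp_model = DDP(nn.Linear(16, 16))  # placeholder model

# process_group=None makes the hook use the default (world) group.
state = PowerSGDState(process_group=None, matrix_approximation_rank=1)
ddp_model.register_comm_hook(state, batched_powerSGD_hook)

print(type(ddp_model).__name__)  # prints "DistributedDataParallel"
dist.destroy_process_group()
```

Note that a communication hook can be registered only once per DDP instance, and it should be registered before the first backward pass.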