Hi all!
I am using the following code to compute the gradient norm under FSDP:
import os

import torch
import torch.distributed as dist
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP, ShardingStrategy


def compute_grad_norm(model: torch.nn.Module):
    # Accumulate the squared sum of the gradients held locally on this rank.
    g_sqsum = 0.0
    for p in model.parameters():
        g = p.grad
        if g is not None:
            g_sqsum += (g ** 2).sum().item()
    return torch.tensor(g_sqsum, device=torch.cuda.current_device())


if __name__ == "__main__":
    dist.init_process_group("nccl")
    world_size = int(os.environ["WORLD_SIZE"])

    # MyModel, wrapping_policy, and bfSixteen_mixed are defined elsewhere in my setup.
    model = MyModel()
    model = FSDP(
        model,
        auto_wrap_policy=wrapping_policy,
        device_id=torch.cuda.current_device(),
        sharding_strategy=ShardingStrategy.FULL_SHARD,
        mixed_precision=bfSixteen_mixed,
        limit_all_gathers=True,
        sync_module_states=False,
        param_init_fn=None,
        use_orig_params=True,
    )

    # after loss.backward()
    grad_norm = compute_grad_norm(model)
    dist.all_reduce(grad_norm, op=dist.ReduceOp.SUM)
    grad_norm = grad_norm / world_size
I am wondering: is it correct to divide grad_norm by world_size after the all_reduce in FSDP? To the best of my knowledge, under FSDP each device only retains a shard of the parameters (and therefore of the gradients), so the all_reduce with SUM should already yield the total grad_norm across all ranks, and there should be no need to divide it by world_size.
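
Here is a toy, single-process sketch of the reasoning I have in mind (no FSDP involved; the shard sizes and values are made up purely for illustration):

import torch

# Pretend a flat gradient vector is split into disjoint per-rank shards,
# the way I understand FULL_SHARD to split parameters across ranks.
full_grad = torch.randn(10)
shards = torch.chunk(full_grad, 2)  # hypothetical world_size = 2

# What each rank would contribute locally: the squared sum over its own shard.
local_sqsums = [(s ** 2).sum() for s in shards]

# all_reduce(SUM) would add these local contributions together.
reduced = sum(local_sqsums)

# The reduced value already equals the squared sum over the *full* gradient,
# so (as I understand it) dividing by world_size would not be needed.
print(torch.allclose(reduced, (full_grad ** 2).sum()))  # True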