FSDP: how to compute grad norm?

Hi all!
I am using the following code to compute the gradient norm with FSDP:

import os
import torch
import torch.distributed as dist
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP, ShardingStrategy

def compute_grad_norm(model: torch.nn.Module):
    # Sum of squared gradient elements over the locally held (sharded) parameters.
    # Note: this is the squared L2 norm of the local shard, not the norm itself.
    g_sqsum = 0.0
    for p in model.parameters():
        g = p.grad
        if g is not None:
            g_sqsum += (g ** 2).sum().item()
    return torch.tensor(g_sqsum, device=torch.cuda.current_device())

if __name__ == '__main__':
    dist.init_process_group("nccl")
    # Bind each process to its GPU so torch.cuda.current_device() is correct under torchrun
    torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))
    world_size = int(os.environ["WORLD_SIZE"])

    # MyModel, wrapping_policy, and bfSixteen_mixed are defined elsewhere in my script
    model = MyModel()
    model = FSDP(
        model,
        auto_wrap_policy=wrapping_policy,
        device_id=torch.cuda.current_device(),
        sharding_strategy=ShardingStrategy.FULL_SHARD,
        mixed_precision=bfSixteen_mixed,
        limit_all_gathers=True,
        sync_module_states=False,
        param_init_fn=None,
        use_orig_params=True,
    )

    # ... training step, after loss.backward() ...
    grad_norm = compute_grad_norm(model)
    dist.all_reduce(grad_norm, op=dist.ReduceOp.SUM)
    grad_norm = grad_norm / world_size

I am wondering: is it correct to divide grad_norm by world_size after the all_reduce in FSDP? As far as I understand, with FULL_SHARD each rank only holds its own shard of the parameters (and gradients), so the per-rank sums of squared gradients are over disjoint elements. After the SUM all_reduce, the global squared gradient norm should therefore already be complete, and there should be no need to divide by world_size; the only remaining step would be taking the square root.
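
For clarity, this is what I think the computation should look like (just a sketch of my understanding, assuming FULL_SHARD with use_orig_params=True so that each rank's p.grad covers a disjoint shard of the full gradient):

# Sketch of what I believe is correct: sum the squared-gradient contributions
# from all shards, then take the square root; no division by world_size.
local_sqsum = compute_grad_norm(model)               # per-rank sum of squared grads
dist.all_reduce(local_sqsum, op=dist.ReduceOp.SUM)   # combine disjoint shards
grad_norm = local_sqsum.sqrt()                       # global L2 gradient norm

Is this reasoning correct, or am I missing something about how FSDP shards gradients?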