Autograd.grad with FSDP


I noticed that the output of

grad = torch.autograd.grad(loss, fsdp_module.parameters())

differs across nodes. However, I thought PyTorch FSDP performs a reduce_scatter during the backward pass, so shouldn't grad be the same on every node?

Another thing I noticed is that the length of grad matches the full parameter count rather than the sharded parameter count.
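For context, here is a minimal non-distributed sketch of what I understand torch.autograd.grad to do (a plain nn.Sequential standing in for the FSDP-wrapped module, so no sharding is actually happening here): it returns exactly one tensor per parameter it is given, each shaped like that parameter. So my assumption was that under FSDP the length of the tuple would reflect whatever fsdp_module.parameters() yields on that rank:

```python
import torch

# Hypothetical toy model standing in for the FSDP-wrapped module.
model = torch.nn.Sequential(torch.nn.Linear(4, 8), torch.nn.Linear(8, 2))
x = torch.randn(3, 4)
loss = model(x).sum()

params = list(model.parameters())
grads = torch.autograd.grad(loss, params)

# autograd.grad returns one gradient tensor per parameter passed in,
# with the same shape as the corresponding parameter.
assert len(grads) == len(params)  # 4 here: two weights + two biases
assert all(g.shape == p.shape for g, p in zip(grads, params))
```

This is what made the FSDP behavior surprising to me: I expected the returned tuple to line up with the (sharded) parameters the wrapped module exposes.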

In general, I am confused about how torch.autograd.grad works with FSDP. Any clarification would be much appreciated!