Hello,
I notice that the output of
grad = torch.autograd.grad(loss, fsdp_module.parameters())
is different on different nodes. However, I thought PyTorch FSDP performs a reduce_scatter during the backward pass, so grad
should be the same on every node?
Another thing I noticed is that the length of grad
is the same as the full (unsharded) parameter count rather than the sharded parameter count.
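For reference, here is a minimal sketch of what I am doing. The toy model, random data, 2-GPU torchrun launch, and the use_orig_params=True flag are just placeholders standing in for my actual setup:

```python
# Minimal sketch (assumptions: launched with torchrun --nproc_per_node=2,
# NCCL backend, toy linear model and random data in place of my real
# model/loss, use_orig_params=True so parameters() yields the original params).
import os
import torch
import torch.distributed as dist
import torch.nn as nn
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

def main():
    dist.init_process_group("nccl")
    rank = dist.get_rank()
    torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))

    model = nn.Sequential(nn.Linear(16, 16), nn.ReLU(), nn.Linear(16, 1)).cuda()
    fsdp_module = FSDP(model, use_orig_params=True)

    x = torch.randn(8, 16, device="cuda")  # each rank sees different random data
    loss = fsdp_module(x).sum()

    # The call in question: the returned tuple has one entry per (unsharded)
    # parameter, and the gradient values printed here differ from rank to rank.
    params = list(fsdp_module.parameters())
    grad = torch.autograd.grad(loss, params)
    print(f"rank {rank}: {len(grad)} grads, norms {[g.norm().item() for g in grad]}")

    dist.destroy_process_group()

if __name__ == "__main__":
    main()
```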
In general, I am confused about how torch.autograd.grad
works with FSDP. Any clarification would be much appreciated!