In FSDP2, a module’s post-backward hook is registered through an autograd function that runs the hook once the gradients of the module’s inputs have been computed.
But it seems like the hook assumes that, at the moment it runs, the gradients of the module’s weights have also been computed. Is it actually guaranteed that the weight gradients are ready by the time the input gradients are? If so, how?
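To make the question concrete, here is a minimal sketch of the registration pattern I mean (this is my own simplified reproduction, not FSDP2’s actual code): a custom `torch.autograd.Function` passes the inputs through in `forward`, and its `backward` fires a hook once the engine has produced gradients w.r.t. those inputs. Inside the hook I check whether `weight.grad` has been accumulated yet.

```python
import torch

# Simplified sketch (not FSDP2's actual implementation): an autograd
# Function whose backward runs a hook once the gradients of the wrapped
# input tensors have been computed.
class RunHookOnInputGrads(torch.autograd.Function):
    @staticmethod
    def forward(ctx, hook, *inputs):
        ctx.hook = hook
        return inputs  # pass the inputs through unchanged

    @staticmethod
    def backward(ctx, *grad_inputs):
        # The engine has produced gradients w.r.t. the *inputs* here;
        # whether the parameter gradients are also ready at this point
        # is exactly the question.
        ctx.hook()
        return (None,) + grad_inputs  # no gradient for the hook argument


linear = torch.nn.Linear(4, 4)
observed = {}

def hook():
    # Record whether weight.grad has already been accumulated
    # at the moment the input gradients become available.
    observed["weight_grad_ready"] = linear.weight.grad is not None

x0 = torch.randn(2, 4, requires_grad=True)
(x1,) = RunHookOnInputGrads.apply(hook, x0)
linear(x1).sum().backward()
print(observed)
```

Note that no expected value is asserted for `observed["weight_grad_ready"]`: whether it is `True` here is precisely what I am unsure about.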
I would assume all gradients have been computed and are ready if you are using a post-backward hook, since the hook is called after the backward call. Let me know if I misunderstood your question.
Well, the hook is registered on the input tensors, not on the weights, so it’s not clear to me that the weight gradients will also have been computed and accumulated by that point.
I don’t know much about the autograd engine internals, but it seems possible for it to be in a state where one part of the backward graph has already been executed (the part containing the input tensors’ nodes, in this case) while another part (the part containing the weight nodes) has not been executed yet.