I was reviewing PyTorch’s Fully Sharded Data Parallel (FSDP) implementation and its associated paper, and I came across a point of confusion. It seems that after the forward pass, all unsharded parameters are discarded. However, my understanding is that their activations need to be retained in order to compute gradients for those unsharded parameters during the backward pass.
Are the activations for the unsharded parameters still being kept somewhere? If not, how does FSDP compute gradients for the unsharded parameters during the backward pass, using the same batch of data that is available on the current device?
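For concreteness, here is a minimal sketch of the kind of setup I have in mind. The model, shapes, and launch details are just placeholders, and I'm assuming the default process group has already been initialized (e.g. the script is launched via torchrun):

```python
import torch
import torch.nn as nn
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

# Assumption: torch.distributed.init_process_group("nccl") has already been called
# (e.g. the script was launched with torchrun), so FSDP can use the default group.
model = FSDP(nn.Linear(1024, 1024).cuda())

x = torch.randn(8, 1024, device="cuda")
out = model(x)      # forward: parameters are all-gathered (unsharded), then freed afterwards
loss = out.sum()
loss.backward()     # backward: gradients for those parameters are still computed -- how?
```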