FSDP all-gather during backward pass

Why does FSDP need to issue an all-gather during the backward pass? I understand why an all-gather is needed during the forward pass to materialize the FSDP unit. I also understand why gradients are reduce-scattered during the backward pass: each GPU ends up with the averaged shard of the gradient that it owns.


A main goal of FSDP is to reduce the memory consumption of a model, so that you can train a larger model than would otherwise fit on the GPU.

It does this by grouping the model’s layers into ‘units’ and then sharding the storage of each unit across the GPUs. With N GPUs, each GPU only has to store 1/N of the model weights.
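To make the 1/N storage concrete, here is a single-process sketch. It does not use torch.distributed, and shard_flat and all_gather_sim are hypothetical helper names, not FSDP's API. The sketch flattens one weight, zero-pads it so it divides evenly, and splits it into one shard per simulated GPU. Here the "all-gather" is just a concatenation. Real FSDP performs the gather with a collective across ranks, and its exact flattening and padding details may differ.

```python
import torch

def shard_flat(param, world_size):
    """Hypothetical helper: split a parameter into world_size equal 1-D shards.

    The flattened tensor is zero-padded at the end so every simulated GPU
    stores the same number of elements (about 1/N of the parameter).
    """
    flat = param.detach().reshape(-1)
    shard_len = -(-flat.numel() // world_size)  # ceiling division
    padded = torch.zeros(shard_len * world_size, dtype=flat.dtype)
    padded[: flat.numel()] = flat
    return list(padded.chunk(world_size))

def all_gather_sim(shards, shape):
    """Hypothetical stand-in for an all-gather: concatenate every shard, drop padding."""
    numel = torch.Size(shape).numel()
    return torch.cat(shards)[:numel].reshape(shape)

w = torch.randn(10, 3)                      # 30 elements
shards = shard_flat(w, world_size=4)        # 4 shards of 8 elements (2 are padding)
print([s.numel() for s in shards])
print(torch.equal(all_gather_sim(shards, w.shape), w))
```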

However, when a unit's forward step actually runs, all of its layers’ weights must be present. This is the reason for the all-gather during forward. The backward pass for that same unit needs the full weights again, because computing dInput (the gradient with respect to the layer's input) requires the weight matrix. We could simply keep the weights in memory after the forward all-gather, but then we’d defeat the memory savings we were using FSDP for in the first place.
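The gather, free, re-gather pattern can be sketched as a custom torch.autograd.Function for a single linear layer. This is a single-process simulation under several assumptions. ShardedLinearFn is a hypothetical name. The "all-gather" is torch.cat over shards that all live in one process. The "reduce-scatter" just hands each shard its own rows of the weight gradient; with only one rank, nothing is averaged. The sketch does show that only the shards and the input are saved for backward, so the full weight has to be rebuilt before grad_input can be computed.

```python
import torch

class ShardedLinearFn(torch.autograd.Function):
    """Hypothetical single-process simulation of one FSDP-style unit (y = x @ W.T)."""

    @staticmethod
    def forward(ctx, x, *shards):
        full_w = torch.cat(shards, dim=0)          # simulated all-gather for forward
        ctx.save_for_backward(x, *shards)          # save the input and the shards, not full_w
        ctx.shard_rows = [s.shape[0] for s in shards]
        return x.mm(full_w.t())                    # full_w is released when forward returns

    @staticmethod
    def backward(ctx, grad_output):
        x, *shards = ctx.saved_tensors
        full_w = torch.cat(shards, dim=0)          # simulated all-gather again, for backward
        grad_x = grad_output.mm(full_w)            # dInput needs the full weight
        grad_w = grad_output.t().mm(x)             # dWeight needs only x and grad_output
        # Simulated reduce-scatter: each shard receives its own rows of grad_w
        # (single process, so there is nothing to average).
        grad_shards = grad_w.split(ctx.shard_rows, dim=0)
        return (grad_x, *grad_shards)

# Usage: a 6x4 weight split row-wise into 3 shards of 2 rows each.
torch.manual_seed(0)
w = torch.randn(6, 4)
shards = [p.clone().requires_grad_() for p in w.chunk(3, dim=0)]
x = torch.randn(5, 4, requires_grad=True)
y = ShardedLinearFn.apply(x, *shards)
y.sum().backward()
```

Comparing against a plain x.mm(w.t()) shows the outputs and gradients match. The design choice being illustrated is that full_w never outlives a single forward or backward call.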


Thanks for your response. The question I’ve been grappling with is: when computing the gradients for a particular layer, do we really need its parameters? I’m looking for specific scenarios where backpropagation actually requires a layer’s parameters. Many neural net functions have the form f(W, x) = Wx, where W is a weight matrix and x is the input data or activations. Since df/dW = x, the gradient with respect to W depends only on x (and the upstream gradient), not on W itself. So a layer’s weight-gradient calculation may rarely need its parameters. Please correct me if I’m forgetting something obvious and there is indeed a dependency on the parameters when computing df/dW. Surely some functions will require the parameters (e.g. non-linear functions), but I imagine many won’t.

EDIT: I think the custom linear function example in "Extending PyTorch" (PyTorch 2.0 documentation) illustrates my point, at least for simple functions.
We want df/dW, which that example implements as:
grad_weight = grad_output.t().mm(input)

So if I’m reading this correctly, we take the upstream gradient, transpose it, and matrix-multiply it with the cached input. If we don’t care about the gradient with respect to the input (i.e. the data or activations), we don’t strictly need grad_input = grad_output.mm(weight), which is the only place the weight is used.
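One way to check this reading numerically is a sketch using plain autograd; linear_grads is a hypothetical helper. It runs the same input and upstream gradient through y = x @ w.T with two different weights. The weight gradient comes out identical and matches grad_output.t().mm(input). The input gradient changes with the weight and matches grad_output.mm(weight).

```python
import torch

def linear_grads(x, w, grad_out):
    """Hypothetical helper: autograd gradients of y = x @ w.T given dL/dy = grad_out."""
    x = x.detach().requires_grad_()
    w = w.detach().requires_grad_()
    y = x.mm(w.t())
    y.backward(grad_out)
    return x.grad, w.grad

torch.manual_seed(0)
x = torch.randn(5, 4)          # cached input (activations)
g = torch.randn(5, 3)          # upstream gradient, grad_output
w_a, w_b = torch.randn(3, 4), torch.randn(3, 4)

gx_a, gw_a = linear_grads(x, w_a, g)
gx_b, gw_b = linear_grads(x, w_b, g)

# grad_weight = grad_output.t().mm(input): the weight's value never enters.
print(torch.allclose(gw_a, gw_b), torch.allclose(gw_a, g.t().mm(x)))
# grad_input = grad_output.mm(weight): depends on which weight was used.
print(torch.allclose(gx_a, g.mm(w_a)), torch.allclose(gx_a, gx_b))
```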

The documentation example also raises another question: if we cache our parameters for the backward pass (see ctx.save_for_backward(input, weight, bias)), how do we accommodate all that extra memory? After all, FSDP is all about reducing the memory footprint of the system, and this looks like it holds onto an entire copy of the layer’s parameters as part of the saved-for-backward context.
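Here is a sketch of one way to inspect what gets saved, assuming PyTorch 1.10 or later for torch.autograd.graph.saved_tensors_hooks. The pack hook sees every tensor autograd saves for backward. The weight it receives shares the parameter's storage (same data_ptr), so saving for backward keeps a reference, not a copy. In FSDP the concern is slightly different. That reference would keep the all-gathered full weight alive until backward, and this is the memory cost the backward re-gather avoids.

```python
import torch

saved_ptrs = []

def pack_hook(t):
    # Called for every tensor autograd saves for backward; record where its data lives.
    saved_ptrs.append(t.data_ptr())
    return t  # store the tensor itself, no copy is made here

def unpack_hook(t):
    return t

weight = torch.randn(3, 4, requires_grad=True)  # stands in for a layer's parameter
x = torch.randn(5, 4, requires_grad=True)

with torch.autograd.graph.saved_tensors_hooks(pack_hook, unpack_hook):
    y = torch.mm(x, weight.t())                 # same math as input.mm(weight.t())

# Is the weight that autograd saved backed by the parameter's own storage?
print(weight.data_ptr() in saved_ptrs)
y.sum().backward()                              # backward still works as usual
```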

I’d greatly appreciate anyone weighing in here to clarify how the backward pass works in FSDP and why the all-gather is strictly necessary.

EDIT 2: I think I’ve figured out where my reasoning went wrong. For multilayer networks, we do care about the gradients with respect to the inputs (i.e. data or activations), and I had to remind myself how the chain rule is applied. Take, for example, a 2-layer neural network made of two simple linear layers. We have:
y = f(W_1, x) = W_1 x   # layer 1
z = g(W_2, y) = W_2 y   # layer 2
L = L(z)                # loss

My reasoning was that dL/dW_2 only needs y (and the upstream gradient dL/dz), so why would I care about all-gathering W_2? What I was forgetting was that dL/dW_1 = (dL/dz)(dz/dy)(dy/dW_1). As we saw above (grad_input = grad_output.mm(weight)), computing the dz/dy term requires the full W_2, because dz/dy = W_2. That is why W_2 has to be all-gathered again during backward.
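The two-layer example can be checked with autograd. This sketch keeps the column-vector convention of the equations above. It assumes a hypothetical loss L = 0.5 * ||z||^2, chosen so that dL/dz = z. The chain-rule terms computed by hand match autograd, and dL/dy is the only term that uses W_2.

```python
import torch

torch.manual_seed(0)
x = torch.randn(4)                          # input, as a column vector
W1 = torch.randn(3, 4, requires_grad=True)
W2 = torch.randn(2, 3, requires_grad=True)

y = W1 @ x                                  # layer 1: y = W_1 x
z = W2 @ y                                  # layer 2: z = W_2 y
L = 0.5 * (z ** 2).sum()                    # hypothetical loss, so dL/dz = z
L.backward()

dL_dz = z.detach()
dL_dW2 = torch.outer(dL_dz, y.detach())     # needs y and dL/dz, not W_2
dL_dy = W2.detach().t() @ dL_dz             # needs the full W_2 (the backward all-gather)
dL_dW1 = torch.outer(dL_dy, x)              # needs x and dL/dy

print(torch.allclose(W2.grad, dL_dW2), torch.allclose(W1.grad, dL_dW1))
```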

Again, I would greatly appreciate it if anyone could confirm this reasoning! I still have the question about ctx.save_for_backward, but I trust PyTorch isn’t keeping a copy of the weights in activation memory for use during backprop; otherwise memory usage would blow up.


Yes, you’ve hit on the key point. For df/dW you could skip the weights. But for every layer except the first, you’d need the weights to compute df/dInput, which the previous layer needs for its own gradients.

Additionally, some people use ‘activation checkpointing’, a.k.a. recomputation. Like FSDP, its goal is to save memory. In this case it does so by discarding activation tensors produced during forward and recomputing them during backward. Even if you only wanted to compute df/dW, you’d still need ‘x’ (the input activations from the previous layer). If you had thrown that away, you’d need to recompute x = prev_layer(prev_x) during backward.
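Recomputation can be sketched with torch.utils.checkpoint.checkpoint, assuming a PyTorch version that accepts use_reentrant=False. CountingLinear is a hypothetical nn.Linear subclass that only counts its forward calls. The sketch checkpoints a segment that contains both prev_layer and the next layer, so the intermediate x is not kept, and prev_layer runs a second time during backward.

```python
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint

class CountingLinear(nn.Linear):
    """Hypothetical nn.Linear subclass that counts how many times forward runs."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.calls = 0

    def forward(self, inp):
        self.calls += 1
        return super().forward(inp)

torch.manual_seed(0)
prev_layer = CountingLinear(4, 4)
layer = CountingLinear(4, 2)
prev_x = torch.randn(8, 4, requires_grad=True)

def segment(t):
    x = prev_layer(t)        # intermediate activation 'x' is not saved by checkpoint
    return layer(x)          # layer's dWeight still needs x, so it must be recomputed

out = checkpoint(segment, prev_x, use_reentrant=False)
print("calls after forward:", prev_layer.calls)
out.sum().backward()         # backward re-runs segment to rebuild x
print("calls after backward:", prev_layer.calls)
```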
