Supporting Autograd for Collectives

Francisco Massa, Luca Wehrstedt, Ke Wen, Will Constable

TL;DR: We have avoided defining backwards for collectives for too long, because it was never urgent enough to push through the subtle details and align on an implementation. Now we’re seeing more usage of collectives, and missing backwards support is becoming a serious problem.

Why Has This Taken So Long?

In short, because it’s harder than it sounds and nobody prioritized it enough. Autograd support was part of the ‘goals’ in the first Functional Collective RFC, but it was never a priority because in so many cases (FSDP, DDP, PP) there is no need for collective backwards, and for DTensor users, gradients are handled by DTensor itself. Worse, the gradient formula for a given collective is context-dependent, calling into question the validity of providing a ‘default’ backwards implementation in the first place (e.g. this comment).

I hope this post will serve, first, to demystify the issues around collective backwards. Second, I think there are some immediate actions we can take to improve things. Third, we may need additional discussion on larger changes (e.g. RFCs), and we can refer back to this post as a source of ground truth.

What Are the Correct Backwards Formulas Anyway?

Let’s start by writing out the expected formulas from first principles, avoiding the implementation detail of multiple processes or comm libraries.

Example: naive_all_reduce

We can implement all_reduce as a function taking many input tensors and returning the reduced tensor multiple times as output.

import torch

def naive_all_reduce(inputs: list[torch.Tensor]) -> list[torch.Tensor]:
    # sum the per-"rank" inputs into a single reduced tensor
    common_output = torch.stack(inputs).sum(dim=0)
    # every "rank" receives (a reference to) the same reduced tensor
    return [common_output for _ in inputs]

Because we return multiple copies of the summed output, the gradients that come back from each ‘output path’ will be summed into the gradient of common_output, before being passed back to each of the inputs. So this means our backwards pass is essentially another sum allreduce:

def naive_all_reduce_backward(d_outputs: list[torch.Tensor]) -> list[torch.Tensor]:
    # gradients flowing back from every output path accumulate into common_output's grad
    accumulated_grad = torch.stack(d_outputs).sum(dim=0)
    # each input contributed to common_output via an identity (sum), so each gets the accumulated grad
    return [accumulated_grad for _ in d_outputs]

Using this framing, the mathematical formula for each operator should be unambiguous and indeed easy to verify just by using autograd on top of local torch operations.
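For example, here is a minimal single-process check (no comm library involved, with made-up random upstream gradients) that autograd through naive_all_reduce reproduces the naive_all_reduce_backward formula above:

import torch

inputs = [torch.randn(3, requires_grad=True) for _ in range(4)]   # one "input" per rank
outputs = naive_all_reduce(inputs)

d_outputs = [torch.randn(3) for _ in outputs]                     # upstream grad per output path
loss = sum((out * d).sum() for out, d in zip(outputs, d_outputs))
loss.backward()

expected = torch.stack(d_outputs).sum(dim=0)                      # == naive_all_reduce_backward
for x in inputs:
    assert torch.allclose(x.grad, expected)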

One thing to call out, which will become important later: we are not making assumptions about our input/output tensors being related to one another in any way! This is the default assumption in PyTorch, and indeed it is what you’d get if you ran torch.autograd on naive_all_reduce. But as we will see later, it is not what we actually want for many applications in distributed training.

Collectives

Forward | Backward | Notes
------- | -------- | -----
gather | scatter |
scatter | gather |
reduce (avg, sum, premul_sum) | broadcast | Bitwise reduce ops (band, bor, bxor) not supported for grad
reduce (max, min) | identity (for the max/min src); scaled (for a tie); 0 (for others) |
reduce (product) | fwd_out / fwd_in * dout |
broadcast | reduce(sum) |
all_reduce (avg, sum, premul_sum) | all_reduce(sum) | Common exception, e.g. Megatron TP, see below; bitwise reduce ops (band, bor, bxor) not supported for grad
all_reduce (max, min) | all_reduce(sum) (for the max/min src); 0 (for others) |
all_reduce (product) | fwd_out / fwd_in * all_reduce(sum, dout) |
all_gather | reduce_scatter(sum) |
reduce_scatter | all_gather |
all_to_all | all_to_all |
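As a concrete illustration of the all_reduce rows, here is a minimal sketch (not an existing PyTorch API) of what a differentiable all_reduce(sum) could look like as an autograd.Function, assuming torch.distributed is initialized with a default process group:

import torch
import torch.distributed as dist

class AllReduceSum(torch.autograd.Function):
    # Differentiable all_reduce(sum): per the table above, its backward is another all_reduce(sum).

    @staticmethod
    def forward(ctx, x):
        out = x.clone()                  # don't mutate the autograd input in place
        dist.all_reduce(out, op=dist.ReduceOp.SUM)
        return out

    @staticmethod
    def backward(ctx, grad_out):
        grad_in = grad_out.clone()       # incoming grads may be shared; copy before reducing
        dist.all_reduce(grad_in, op=dist.ReduceOp.SUM)
        return grad_in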

Non-SPMD cases

We probably should not implement backwards automatically for non-SPMD communication operations, even though the mathematical formula is trivial. This is because the ordering of operations during the backwards pass is critical to maintain correctness and avoid hangs, but we have no practical way of ensuring that we execute backwards graphs in a particular order. This can technically affect SPMD collectives too, but it is particularly bad for non-SPMD collectives.

Forward | Backward
------- | --------
send | recv
recv | send

Now that we have enumerated the backwards formulas, let’s see how people actually use collectives in practice.

These ‘mathematically correct’ backwards formulas for collectives are definitely useful in some cases; we have seen several examples recently: Tristan Rice used some for his Context parallelism prototype (code), and Ke Wen used some for his DeepSeek implementation (code).

Exceptions That Overrule The Rule

Let’s take a look at Tensor Parallelism as described in the original Megatron paper [1909.08053]. This is just one example, but the pattern it describes is extremely common in deep learning (more so than the examples I could list above).

Figure 3. Blocks of Transformer with Model Parallelism. f and g are conjugate. f is an identity operator in the forward pass and all reduce in the backward pass while g is an all reduce in the forward pass and identity in the backward pass.

Communication operations f and g are introduced to perform a particular pattern of communication during forward and backwards that does not match any particular operator in our table above!

The whole point of f and g is to make a semantic leap. f takes a replicated tensor and makes it non-replicated. g merges non-replicated tensors back to replicated ones.
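Concretely, f and g are usually written as a conjugate pair of autograd functions. Here is a minimal sketch (illustrative names, default process group assumed); note that g’s backward is an identity, not the all_reduce(sum) listed in the table above:

import torch
import torch.distributed as dist

class _F(torch.autograd.Function):
    # f: identity in forward, all_reduce(sum) in backward
    @staticmethod
    def forward(ctx, x):
        return x
    @staticmethod
    def backward(ctx, grad_out):
        grad = grad_out.clone()
        dist.all_reduce(grad, op=dist.ReduceOp.SUM)
        return grad

class _G(torch.autograd.Function):
    # g: all_reduce(sum) in forward, identity in backward
    @staticmethod
    def forward(ctx, x):
        out = x.clone()
        dist.all_reduce(out, op=dist.ReduceOp.SUM)
        return out
    @staticmethod
    def backward(ctx, grad_out):
        return grad_out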

PyTorch DTensor has a similar concept. DTensors automatically track whether a tensor is a ‘Replica’ (e.g. X input to f) or ‘Partial’ (e.g. Z1 input to g), and offers APIs like ‘redistribute()’ to convert between these layouts.

JAX uses typing to tag values as ‘device-variant’ or ‘device-invariant’ [docs], and defines a set of operators that move values back and forth [docs]. pbroadcast is equivalent to f, while psum2 is equivalent to g. Note: psum itself also exists, and is equivalent to ‘all_reduce’, and the whole reason for introducing psum2 is to solve the issues outlined above.

Figure 4. Snippet from the table of operators introduced in JAX to transform between device-invariant and device-variant states.

So What Should We Do?

To recap, we can define mathematically pure gradient formulas for our collectives, but they don’t address some of the most common use cases. And we have a higher-level concept (call it replica vs partial, or device invariant vs variant) that we can make clearer and better supported in PyTorch.

Fix and Improve Collective APIs

  1. [P0] We should patch the silent-correctness issue with functional collectives and C10d collectives by throwing an error in backwards ([C10D] Autograd Support for Collectives · Issue #152131 · pytorch/pytorch · GitHub).

  2. [P1] Implement mathematically pure backwards formulas for functional collectives and C10d collectives. (Note: a naive backwards impl would be synchronous.) ([C10D] Autograd Support for Collectives · Issue #152131 · pytorch/pytorch · GitHub)

  3. Deduplicate funcol, autograd_funcol, and nn.functional

  4. [P2 / Optional] Improve eager backwards performance. (Alternative: rely on compile)

Explicitly Support the Device-Varying Case

Luca Wehrstedt and Francisco Massa point out that the existing DTensor internal helper method redistribute_local_tensor is very nearly what we’d need to offer a set of manual APIs that let users convert plain torch.Tensor between logically ‘replicated’ and ‘partial/device-varying’ states.

For users that don’t want to use DTensor for their whole program, we can still leverage its internals to offer a manual API that takes care of correctly transforming tensors between device-varying and device-invariant forms. In other words, we can offer the equivalent of f and g operations (as well as others described by JAX).
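As a rough sketch of the kind of manual helper being discussed, built on the public DTensor API rather than the internal redistribute_local_tensor (module paths differ across PyTorch versions, a 1-D mesh over all ranks is assumed, and whether the backward of this conversion does exactly what the Megatron-style f/g need is one of the open questions below):

import torch
import torch.distributed as dist
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor import DTensor, Partial, Replicate

mesh = init_device_mesh("cuda", (dist.get_world_size(),))   # assumes the process group is initialized

def partial_to_replicate(local_partial: torch.Tensor) -> torch.Tensor:
    # Megatron-style g: merge per-rank partial results into a replicated tensor
    # (an all_reduce in forward), then immediately drop the DTensor wrapper.
    dt = DTensor.from_local(local_partial, mesh, [Partial()])
    return dt.redistribute(mesh, [Replicate()]).to_local()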

However, there are a few things we might need to harden about DTensor’s semantics. We’ll have to dive a little deeper here.

  • Currently DTensor sharding rules are not as disciplined as they could be: A ‘replica’ tensor can directly be used as input to a matmul with a ‘shard’ input, while in a stricter sense we should force DTensor to convert the replica to a ‘partial’ first
  • ‘redistribute_local_tensor’ has some questionable logic about skipping partial to replica conversion during backwards

I plan to look further into the explicit API and DTensor case, and update RFCs on pytorch accordingly. Hopefully this note clarifies the topic for folks, and if there is still something unclear or a mistake in the reasoning, please let me know!


Thanks for writing this down!

> Currently DTensor sharding rules are not as disciplined as they could be: A ‘replica’ tensor can directly be used as input to a matmul with a ‘shard’ input, while in a stricter sense we should force DTensor to convert the replica to a ‘partial’ first

Can you elaborate on this more? I don’t understand the statement here. No partial input should be accepted by a matmul.

> ‘redistribute_local_tensor’ has some questionable logic about skipping partial to replica conversion during backwards

Hmm what is questionable here?

I feel supporting collective autograd formulas is a good direction and could help dedupe and remove the undocumented nn.functional collectives. There’s a concern about backward compatibility (i.e. how to actually dedup those implementations; do the autograd versions of collectives stay as separate APIs to avoid BC breaking?). I feel a concrete proposal on the details would be more helpful 🙂

But I think it is questionable whether to expose some manual f/g operations, as the f/g operations actually embed the concept of sharding implicitly, without any explicit sharding annotations. There is not something general enough that could be exposed, because the input shardings would be different from case to case, i.e. redistribute_local_tensor needs the user to explicitly pass src/dst sharding specs in order to deduce the correct operation. Also, this API does not implement autograd at all, and if you implement autograd, this is basically the redistribute API DTensor has today.

Honestly, for users who don’t want to use DTensor in the whole program, there’s already a way for them to construct a DTensor via from_local, do redistribute, and convert back to a local tensor. Each operation is very explicit and general already, so I don’t see a need to introduce some other API to achieve the same thing.

IMO it’s better to separate the DTensor discussion from the autograd collectives posted here, as I think they are cleanly two separate layers of abstraction.

Totally agree that this is ‘implicit’ and that makes it dangerous. Point of fact, lots of users do it anyway. I think it’s worth offering an API to do it and documenting what it means. But I also think it’s cleaner and safer to just use DTensor. We should recommend this path for this reason.

> Can you elaborate on this more? I don’t understand the statement here. No partial input should be accepted by a matmul.

Yea I need to spend more time looking into this part. I will try to write something further once I do.

> redistribute_local_tensor needs the user to explicitly pass src/dst sharding specs in order to deduce the correct operation.

This part seems fine to me. Either (a) the API is that users pass src/dst sharding specs manually, or (b) we offer APIs like ‘psum2’ in JAX, which desugars into this. Either way, it’s a way to manipulate these states even though there is no metadata that makes sure you don’t mess up.

> Also this API does not implement autograd at all, and if you implement autograd, this is basically the redistribute API DTensor has today

Yea, maybe what I actually want is to use redistribute but then remove the DTensor wrapper immediately. I definitely want autograd supported inside this API. The idea would be that users can do the ‘replicate → partial’ transformation explicitly during forward (which is really a no-op), and then they would get an all-reduce ‘for free’ during backwards.


A couple thoughts:

> Worse, the gradient formula for a given collective is context-dependent, calling into question the validity of providing a ‘default’ backwards implementation in the first place (e.g. this comment).

I was talking with Yifu, and in some sense, I agree with him that there is indeed a single “correct” backwards implementation of all-gather (or any other collective). The argument is similar to the one given with naive_all_reduce - namely, if you view every output of the collective as just another usage of the tensor, then the backwards pass should clearly just be the sum of the gradients of every output!

For example, say we had

x1, x2, x3, x4 = y
loss = f(x1, x2, x3, x4)

clearly, gradY = sum(gradX1, gradX2, gradX3, gradX4)!

However, if we were to talk about this from the perspective of DTensor, what are we saying about the layouts of gradX1/gradX2/...? We are saying that these are all partial sums of gradY.

i.e. when I write this and it runs on 4 ranks

y = allreduce(x)

in reality, I actually have y_0, y_1, y_2, y_3.

So, using DTensor vocab, we are guaranteeing that the gradY above is Partial. In other words, from the autograd engine’s perspective, gradY only contains the loss gradient wrt y_0 and not y_1/y_2/y_3. However, in DTensor, there is another layout it could be - Replicate. And if it were replicated, then that would mean that gradY_0 contains sum(gradY_i), and so if the backwards pass were an allreduce that would be wrong.

So, if we have our differentiable collectives, can we guarantee that we’ll never run into the Replicate case?

Actually, yes! It is consistent for us to say that on every single device, gradX always represents the gradient contribution of X to the final loss. Thus, the proposal above is sufficient - allgather ↔ reducescatter, allreduce ↔ allreduce, and we’re good!

However, there were a couple of inconsistencies here that made me feel like this wasn’t a complete picture. For example:

  1. If there is only one correct backwards for a collective, why does the original TP formulation in Megatron-LM clearly not use those backwards? (i.e. in figure 3, g is an allreduce, but its backward is the identity!). The post says that these are for a semantic leap, but even so, this seems like a weird contradiction to me. The post writes “the mathematical formula for each operator should be unambiguous” but then for Megatron, the backwards pass of an allreduce is the identity??? How can it both be unambiguous but also variable?
  2. Can we implement DDP with these autogradable collectives? I mean, the backwards of DDP is an allreduce, but there’s no corresponding allreduce in the forwards? That’s kinda weird.

After thinking about it, my conclusion is that: The original sin is that we often want parameters (leaf nodes!) to be replicated across machines. All other inconsistencies flow downstream from this desire.

Basically, this does not make sense from the autograd perspective. Multiple leaf nodes that actually represent the same value? A replicated parameter doesn’t make any sense - the “morally” correct thing to do is to have it live on rank 0 and then be broadcasted out to every rank if we want each rank to use it, which then makes the parameter “sharded”, but it doesn’t matter because the backwards of a broadcast is a “reduce”, which gives us our desired semantics.

So, really, when we have a replicated parameter, before using it, we are essentially inserting a “fake_broadcast(Replicated) → Sharded”. The backwards pass of this is a fake_reduce (implemented with an allreduce). This answers question 2 - if we want to implement DDP then we need to support replicated parameters, and in order to support replicated parameters we need a fake_broadcast, which morally acts as if it broadcasts all values from rank 0, but in reality is a no-op.
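A tiny sketch of what that buys us (FakeBroadcast here is just an illustrative autograd function with a no-op forward and an all_reduce(sum) backward, the input is a stand-in for this rank’s shard of the batch, and a default process group is assumed):

import torch
import torch.distributed as dist

class FakeBroadcast(torch.autograd.Function):
    # Pretend the parameter was broadcast from rank 0: no-op forward, all_reduce(sum) backward.
    @staticmethod
    def forward(ctx, x):
        return x
    @staticmethod
    def backward(ctx, grad_out):
        grad = grad_out.clone()
        dist.all_reduce(grad, op=dist.ReduceOp.SUM)
        return grad

w = torch.nn.Parameter(torch.randn(16, 16))   # replicated leaf, "morally" lives on rank 0
x = torch.randn(8, 16)                        # stand-in for this rank's shard of the batch
loss = (x @ FakeBroadcast.apply(w)).sum()
loss.backward()                               # w.grad is now summed across ranks: DDP's gradient sync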

Now, let’s look at question 1 - why do the collectives in Megatron-LM’s collectives not follow the given autograd rules?

First of all, let’s imagine implementing Megatron-LM using the autogradable collectives listed above. To do so, we would skip f, and g’s backwards would be another allreduce. Kinda like this diagram

Ignoring anything outside of the block (like norms) for a second, note that this is totally correct! Transformer blocks go back to back, so if we remove f from the equation, the gradient immediately flows into the previous transformer block’s g, whose backward performs the allreduce that would have been done by f! So… if they could use these “autogradable” collectives, why did Megatron-LM formulate it differently? And what’s the difference?

As it turns out, there is a difference! And it all comes back around to the original sin - replicated parameters. In between the transformer blocks we also have layernorms, which have parameters. In Megatron-LM, they replicate the layernorm parameters.

So, let’s use the above view that “replicated” actually means “lives on rank 0”. So, in our scheme, what would we do?

To write it out, in our scheme using autogradable collectives, we would have:

X: Sharded # beginning
.... # bunch of stuff
Z: Sharded = all_reduce(X) # g forwards
norm_weight: Rank0
broadcast_norm_weight: Sharded = fake_broadcast(norm_weight) # no-op
Z_normed: Sharded = layer_norm(Z: Sharded, broadcast_norm_weight: Sharded)
X2: Sharded = Z_normed # starts next layer

and thus this in the backwards

grad_X2: Sharded
grad_Z_normed: Sharded = grad_X2
grad_Z: Sharded, grad_broadcast_norm_weight: Sharded = layer_norm_backward(grad_Z_normed: Sharded)
grad_norm_weight: Rank0 = fake_reduce(grad_broadcast_norm_weight) # all-reduce
gradX: Sharded = all_reduce(grad_Z) # g backwards

Note that here, we need to allreduce norm_weight’s gradient, since its gradient computation is sharded.

However, the Megatron-LM strategy looks like this:

X: Sharded # beginning
.... # bunch of stuff
Z: Rank0 = fake_reduce(X) # g forwards
norm_weight: Rank0
Z_normed: Rank0 = layer_norm(Z: Rank0, norm_weight: Rank0)
X2: Sharded = fake_broadcast(Z_normed)  # f forwards

with the corresponding backwards being

grad_X2: Sharded
grad_Z_normed: Rank0 = fake_reduce(grad_X2)  # f backwards
grad_Z: Rank0, grad_norm_weight: Rank0 = layer_norm_backward(grad_Z_normed: Rank0)
gradX: Sharded = broadcast(grad_Z) # g backwards

Basically, because norm_weight is replicated (i.e. on “Rank0”), there “should” be an allreduce in the backwards pass. However, if you ensure that all computation performed with it is also replicated (i.e. on “Rank0”), then you can avoid the fake_broadcast, and thus avoid the allreduce.

Note: The general insight here bears similarities to JAX’s page on “device-varying” vs “device-invariant” values (although on their page it’s kind of a weird special case for the interactions between shmap output and pjit). In their terminology, fake_broadcast == pbroadcast and goes from “device-invariant” to “device-varying”. However, because shmap is fundamentally a function wrapper, their doc doesn’t touch on “replicated parameters”. I also prefer the conceptual framing that they morally only live on one device (i.e. rank0).

Conclusion

This was way too long haha. Arguably, this is a lot of words to not really disagree with anything in the original post. But I hope the explanation of why the “Megatron-LM exception” is not really an exception is useful.

But, my overall thoughts on the plans / what to do next:

  1. I think it makes sense to provide the “autograd-able collectives” with the defined autograd formulas, as suggested above. However, we should emphasize that these assume that no inputs are “replicated”/“device-invariant”/“morally live on rank0”. For example, if you were to implement the megatron-style TP without the additional fake_broadcast, we would end up with different gradients for every rank’s norm_weight, essentially treating that as if it were sharded and not replicated.

  2. I think we should consider adding the collective pairs that map from “replicated” to “sharded” and vice-versa. Specifically, I think there’s “fake_broadcast” (broadcasts from rank0, rank0 → sharded, no-op ↔ allreduce), “fake_reduce” (reduces down to rank 0, sharded → rank0, allreduce ↔ noop), “fake_scatter” (scatters from rank0, rank0 → sharded, split ↔ all-gather), and “fake_gather” (gathers to rank0, sharded → rank0, all-gather ↔ split). We wouldn’t be able to guarantee that users are necessarily using them correctly, but I think this would allow users to basically implement most fundamental collective patterns via autograd, albeit not necessarily efficiently (we could also add a “slow” mode that does a check to ensure that “replicated” values are actually replicated). A rough sketch of one such op is included after this list.

  3. I think we need to think carefully about how DTensor (i.e. global SPMD) interoperates with default PyTorch (i.e. local SPMD).
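Riffing on item 2, here is a rough sketch of one such op, a hypothetical fake_scatter (forward: take this rank’s chunk of a “replicated” tensor, no communication; backward: all-gather the shard gradients). It assumes an even split along dim 0 and an initialized default process group:

import torch
import torch.distributed as dist

class FakeScatter(torch.autograd.Function):
    # The input is assumed to be replicated ("morally lives on rank 0"), so the forward is
    # a local split with no communication and the backward all-gathers the shard gradients.

    @staticmethod
    def forward(ctx, x):
        world_size = dist.get_world_size()
        rank = dist.get_rank()
        return x.chunk(world_size, dim=0)[rank].contiguous()

    @staticmethod
    def backward(ctx, grad_shard):
        world_size = dist.get_world_size()
        grad_shard = grad_shard.contiguous()
        gathered = [torch.empty_like(grad_shard) for _ in range(world_size)]
        dist.all_gather(gathered, grad_shard)
        return torch.cat(gathered, dim=0)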


> maybe what I actually want is to use redistribute but then remove the DTensor wrapper immediately. I definitely want autograd supported inside this API. The idea would be that users can do the ‘replicate → partial’ transformation explicitly during forward (which is really a no-op)

Hmmm, but the “replicate → partial” transformation is not really a no-op in the forward…

It seems like you mainly want to expose some function that achieves the f/g operations in the Megatron-LM style, and you want to reuse some of the DTensor components. IMO exposing something like that with the DTensor concepts is not really easy and might bring more confusion to users (i.e. even you are confused in those cases).

I think the underlying problem here is: DTensor does NOT have a one-to-one mapping to the Megatron f/g operations. The way that DTensor performs Megatron TP achieves functional equivalence with f/g, but DTensor does not simply formalize the execution as two autograd functions. Let’s take the first column-parallel linear as an example:

  • Megatron f autograd function:
    • forward → no-op
    • backward → allreduce
  • DTensor:
    • forward → replicated input @ column sharded weight: no communication
    • backward → column sharded grad output @ row sharded (transposed) weight: partial grad input, needs allreduce

There is no single autograd function that can encapsulate the shardings rationally without explicitly specifying all of those shardings. Specifically, if you look at the details, why in the DTensor case is the forward a no-op? It is because the operation it’s performing is a matmul and the inputs have the right shardings, so the no-op is tied specifically to the matmul operation! Forcing it to be a single autograd function by specifying all of those shardings does not make sense, as you don’t know the following compute operation, so you can’t know whether it should be a no-op or not. If we look closer into the Megatron-LM paper, it actually formalizes the matmul sharding first and then introduces those autograd functions, which also means f/g are tied to the specific matmul sharding strategy.

This is why I think DTensor and collectives are two clearly separated levels of abstraction, so I still don’t think the “redistribute but then remove the DTensor wrapper” approach is the right thing to do, at least not with the DTensor concepts. 🙂

Great insights about the implications of replicated parameters! One minor correction maybe: the last line of the Megatron strategy backwards should be

gradX: Sharded = fake_broadcast(grad_Z) # g backwards

It looks to me that fake_broadcast and fake_reduce basically mimic the f/g operations, but they make it clear why autograd behaves that way using the per-device semantics (or in your terms, local SPMD) mode, which is nice! One thing I wonder is how to rationalize the “fake” in a way users could understand, and whether/how to deal with rank0 broadcast/reduction, as they could bring down efficiency?

I also agree that we should be careful about the behaviors between DTensor (which implements “global SPMD”) and manual collectives with PyTorch (local SPMD). I do think that, as a higher-level abstraction, DTensor hides certain details, so exposing the collective layer is important to users.

For distributed algorithms with PyTorch, what we have seen so far is mostly a state between global SPMD DTensor and local SPMD autograd collectives: users actually implement “global SPMD algorithms” using the PyTorch manual collectives (like Megatron TP). So for the next steps we are discussing, I think we should evaluate them together and answer the following questions:

  • Which option is better? I personally prefer the third option
    • implement the proposed autograd formula for the current collectives (i.e. allreduce) by default and document the per-device semantics clearly
    • offer those autograd enabled collectives as separate APIs, still cleanly document the behaviors so that user can choose to use them as needed
    • keep the current state (don’t implement the autograd formulas for those collectives), but rather provide collective pairs like jax pbroadcast (the second proposal @Chillee mentioned)
    • Do two or all of the above three
  • Whether/how to encode sharded/replicated concepts into any of the proposals above? IMO directly exposing DTensor concepts might be confusing to users, as I replied above.

Sharing some of my thoughts as a note here.

Autogradable all-reduce

For simplicity of discussion, I’m restricting the scope of collectives in this note to all-reduce only. Other collectives should follow similar reasoning.

In general, the backward of all-reduce is still an all-reduce, but the backward reduction op depends on context. Recall that the gradient of an op answers the question “if the input changes by some delta, how much does the output change?”.

Let’s decompose all-reduce into mathematically equivalent reduce → broadcast, where the reduction op for reduce could be e.g. SUM, AVG, etc.

  • The backward of reduce is
    • IDENTITY if the reduction op is SUM.
    • “MULTIPLY 1/world_size” if the reduction op is AVG.
  • The backward of broadcast depends on how the fwd output is consumed / what loss function we use. E.g.
    • If we do .sum().backward() on the output of broadcast, the bwd should be SUM, or equivalently “MULTIPLY world_size”.
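A quick single-process emulation (world size 4, made-up upstream grads) checks that these two backward rules compose as described for all-reduce(AVG):

import torch

W = 4                                             # emulated world size
xs = [torch.randn(3, requires_grad=True) for _ in range(W)]

reduced = torch.stack(xs).mean(dim=0)             # reduce(AVG)
ys = [reduced for _ in range(W)]                  # broadcast: every "rank" sees the same tensor

d_outs = [torch.randn(3) for _ in range(W)]       # per-rank upstream grads
loss = sum((y * d).sum() for y, d in zip(ys, d_outs))
loss.backward()

# broadcast backward: SUM of the upstream grads; reduce(AVG) backward: multiply by 1/W
expected = torch.stack(d_outs).sum(dim=0) / W
for x in xs:
    assert torch.allclose(x.grad, expected)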

Tensor Parallel

In a world where the only existing autogradable comm op is the general all-reduce, i.e. all-reduce(fwd) + all-reduce(bwd), assuming some flexibility in choosing the fwd/bwd reduction ops, we would perform TP as follows.

  • fwd: → all-reduce with AVG (f) → TP → all-reduce with SUM (g) → Norm →
  • bwd: ← all-reduce with SUM (f) ← TP ← all-reduce with AVG (g) ← Norm ←
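A sketch of such a general autogradable all-reduce with independently chosen forward/backward reduction ops (illustrative only; ReduceOp.AVG requires the NCCL backend, and a default process group is assumed):

import torch
import torch.distributed as dist

class AllReduceFwdBwd(torch.autograd.Function):
    # All-reduce whose forward and backward reduction ops can be chosen independently.

    @staticmethod
    def forward(ctx, x, fwd_op, bwd_op):
        ctx.bwd_op = bwd_op
        out = x.clone()
        dist.all_reduce(out, op=fwd_op)
        return out

    @staticmethod
    def backward(ctx, grad_out):
        grad = grad_out.clone()
        dist.all_reduce(grad, op=ctx.bwd_op)
        return grad, None, None

def f(x):  # all-reduce with AVG forward, all-reduce with SUM backward
    return AllReduceFwdBwd.apply(x, dist.ReduceOp.AVG, dist.ReduceOp.SUM)

def g(x):  # all-reduce with SUM forward, all-reduce with AVG backward
    return AllReduceFwdBwd.apply(x, dist.ReduceOp.SUM, dist.ReduceOp.AVG)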

The reason we need op g is because we need the activation to be Replicate (from Partial, in DTensor terms) to perform (Layer)Norm, not because Norm parameters are Replicate (cc Horace). Likewise, the reason we need op f is because we need the gradient to be Replicate (from Partial).

In Megatron TP, the forward of f and the backward of g are no-ops. In fact, they essentially are: because the Norm computation is replicated, the occurrences of “all-reduce with AVG” can be omitted. Therefore, from this perspective, the f and g in Megatron TP are special cases of the general autogradable all-reduce collective. The specialty comes from the device-invariant region, without which there is no correctness guarantee for the no-op parts of Megatron-style f and g. In other words, the autogradable collectives f and g are only valid with context / metadata.

How DTensor does it

As one can see, we need some sort of special treatment to achieve Megatron TP, to avoid the wasteful general autogradable all-reduce. In the case of DTensor, it’s NOT done via autogradable ops; it is done via a metadata-enhanced tensor subclass. In DTensor, communications always happen as a means to change placement (an annotation of global distribution of data), e.g. Partial → Replicate, whether they are explicit via DTensor.redistribute, or implicit via sharding propagation strategies/rules. As Wanchao mentioned in Supporting Autograd for Collectives - #5 by wanchaol, this is functionally equivalent to the Megatron-style f and g, but with a different “philosophy”.

To summarize

DTensor cares primarily about data placement, and derives what collectives to perform based on source and target placements. Users don’t have explicit control over what collectives would be performed, and users probably don’t even care about them. DTensor engine should make it efficient, not the users themselves.

Autogradable collectives, on the other hand, should focus on performing the desired collectives, ideally with a correctness guarantee. E.g., whenever we use Megatron-style f in fwd, we should make sure it enters a device-invariant state. Depending on what users are used to, we may or may not need to support the derivation of correct DTensor placements after an arbitrary collective op.

In a nutshell, I think both are important

  • DTensor provides ergonomic ways to deal with communication details, which could be particularly useful for users who are not experts in distributed.
  • Autogradable collectives provide low-level control that DTensor lacks (and hopefully also a correctness guarantee).

But given how different and complementary they are, we probably should provide different levels of abstraction. I haven’t thought much about the interaction between the two approaches – I tentatively think they should be used disjointly.