Supporting Autograd for Collectives

Maybe what I actually want is to use redistribute but then remove the DTensor wrapper immediately. I definitely want autograd supported inside this API. The idea would be that users can do the ‘replicate → partial’ transformation explicitly during forward (which is really a no-op).

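Hmmm, but the “replicate → partial” transformation is not really a no-op in the forward. A Partial (sum) tensor's value is the sum of the local values across ranks, so if every rank just kept its replicated copy, the value would become N times too large on N ranks. Each rank has to change its local data, e.g. scale it by 1/N, or keep it on one rank and zero it out on the others.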
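A minimal pure-Python sketch of this point follows. It assumes that ranks are simulated as entries of a list and that "partial" means Partial (sum). The helper names are hypothetical and are not DTensor APIs. The sketch shows that relabeling replicated data as partial changes the logical value, while scaling or zeroing the local data preserves it.

```python
# Simulation only: each entry of a list is one "rank"'s local data.
# Helper names are hypothetical; this is not the DTensor API.


def partial_sum_value(local_values):
    """Logical value of a Partial(sum) placement: elementwise sum over ranks."""
    return [sum(vals) for vals in zip(*local_values)]


def relabel_as_partial(x, world_size):
    # Naive "no-op": every rank keeps its full replicated copy unchanged.
    return [list(x) for _ in range(world_size)]


def partial_by_scaling(x, world_size):
    # Every rank scales its copy by 1/world_size so the sum recovers x.
    return [[v / world_size for v in x] for _ in range(world_size)]


def partial_by_zeroing(x, world_size):
    # Rank 0 keeps x; every other rank holds zeros.
    return [list(x) if rank == 0 else [0.0] * len(x) for rank in range(world_size)]


x = [1.0, 2.0]
print(partial_sum_value(relabel_as_partial(x, 4)))  # [4.0, 8.0]: wrong value
print(partial_sum_value(partial_by_scaling(x, 4)))  # [1.0, 2.0]
print(partial_sum_value(partial_by_zeroing(x, 4)))  # [1.0, 2.0]
```

Both valid conversions touch the local data on every rank, which is why the transformation is not free in the forward.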

It seems like you mainly want to expose a function that performs the f/g operations in the Megatron-LM style, and you want to reuse some of the DTensor components for it. IMO exposing something like that with DTensor concepts is not easy and might bring more confusion to users (e.g. even you found it confusing in these cases).
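For reference, the Megatron-LM f/g pattern can be sketched in plain Python for the MLP block. Only the forward pass is shown, and two ranks are simulated as list entries. The matrices are made up. f_forward, g_forward and allreduce_sum are hypothetical helpers, not PyTorch or DTensor APIs.

The two operations are conjugates of each other:
  • f is the identity in forward and an allreduce in backward.
  • g is an allreduce in forward and the identity in backward.

```python
# Simulation only: 2 "ranks" as list entries; helper names are hypothetical.


def vec_mat(x, w):
    # Row vector x times matrix w.
    return [sum(x[k] * w[k][j] for k in range(len(x))) for j in range(len(w[0]))]


def relu(v):
    return [max(0, e) for e in v]


def allreduce_sum(per_rank):
    # Simulated allreduce: every rank receives the elementwise sum over ranks.
    total = [sum(vals) for vals in zip(*per_rank)]
    return [list(total) for _ in per_rank]


def f_forward(x_per_rank):
    # Megatron f: identity in forward (its allreduce happens in backward).
    return x_per_rank


def g_forward(z_per_rank):
    # Megatron g: allreduce in forward (identity in backward).
    return allreduce_sum(z_per_rank)


X = [1, 2]
A = [[1, -2, 3, -4], [5, 6, -7, 8]]    # column-parallel weight
B = [[1, 0], [0, 1], [1, 1], [2, -1]]  # row-parallel weight
A_shards = [[row[0:2] for row in A], [row[2:4] for row in A]]  # split columns
B_shards = [B[0:2], B[2:4]]                                     # split rows

x_local = f_forward([list(X), list(X)])
h_local = [relu(vec_mat(x, a)) for x, a in zip(x_local, A_shards)]  # no communication
z_local = [vec_mat(h, b) for h, b in zip(h_local, B_shards)]        # partial sums
out = g_forward(z_local)

print(out[0])                           # [35, -2]
print(vec_mat(relu(vec_mat(X, A)), B))  # [35, -2], the unsharded reference
```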

I think the underlying problem here is that DTensor does NOT have a one-to-one mapping to the Megatron f/g operations. The way DTensor performs Megatron TP achieves “functional equivalence” with f/g, but DTensor does not simply formalize the execution as two autograd functions. Let’s take the first column-parallel linear as an example:

  • Megatron f autograd function:
    • forward → no-op
    • backward → allreduce
  • DTensor:
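    • forward → replicated input @ column-sharded weight: no communication (the output comes out column-sharded)
    • backward → column-sharded grad output @ transposed weight (row-sharded, since the weight is column-sharded): the grad input is partial and needs an allreduce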
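To make the DTensor side of this comparison concrete, here is a pure-Python simulation of the column-parallel linear. It makes three assumptions:
  • two ranks, represented as list entries
  • the math convention Y = X @ W, with W of shape (in, out)
  • small hand-picked matrices

The helpers (matmul, column_shards, allreduce_sum) are hypothetical stand-ins, not torch.distributed or DTensor calls.

```python
# Simulation only: a two-"rank" column-parallel linear Y = X @ W.
# Ranks are list entries; no real communication happens here.


def matmul(a, b):
    return [[sum(a[i][k] * b[k][j] for k in range(len(b))) for j in range(len(b[0]))]
            for i in range(len(a))]


def transpose(m):
    return [list(row) for row in zip(*m)]


def column_shards(m, world_size):
    # Split the columns of m evenly across ranks (assumes divisibility).
    width = len(m[0]) // world_size
    return [[row[r * width:(r + 1) * width] for row in m] for r in range(world_size)]


def allreduce_sum(per_rank):
    # Simulated allreduce: every rank receives the elementwise sum over ranks.
    total = [[sum(vals) for vals in zip(*rows)] for rows in zip(*per_rank)]
    return [[list(row) for row in total] for _ in per_rank]


world_size = 2
X = [[1, 2]]                      # replicated input, shape (1, 2)
W = [[1, 2, 3, 4], [5, 6, 7, 8]]  # full weight, shape (2, 4)
W_shards = column_shards(W, world_size)

# Forward: each rank multiplies the replicated input by its column shard.
# No communication; the output is column-sharded.
Y_shards = [matmul(X, W_r) for W_r in W_shards]

# Backward: the incoming gradient is column-sharded like Y.
grad_Y = [[1, 1, 1, 1]]
grad_Y_shards = column_shards(grad_Y, world_size)

# Each rank's grad_X = grad_Y_r @ W_r^T is only a partial sum ...
grad_X_partial = [matmul(g_r, transpose(W_r)) for g_r, W_r in zip(grad_Y_shards, W_shards)]
# ... so an allreduce is needed to get the full gradient on every rank.
grad_X = allreduce_sum(grad_X_partial)

print(Y_shards)        # [[[11, 14]], [[17, 20]]]
print(grad_X_partial)  # [[[3, 11]], [[7, 15]]]
print(grad_X[0])       # [[10, 26]], equal to grad_Y @ W^T
```

Concatenating the Y shards along the columns gives X @ W, so no rank needed another rank's data in the forward. Only the partial grad X in the backward needed the allreduce.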

There is no single autograd function that can encapsulate these shardings cleanly without explicitly specifying all of them. If you look at the details, why is the forward a no-op in the DTensor case? It is because the operation being performed is a matmul and the input already has the right sharding, so the no-op is tied specifically to the matmul! Forcing this into a single autograd function by specifying all of those shardings does not make sense, because the function does not know which compute operation follows and therefore cannot know whether its forward should be a no-op. If we look closer at the Megatron-LM paper, it actually formalizes the matmul sharding first and then introduces those autograd functions, which also means f/g are tied to that specific matmul sharding strategy.
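The claim that f only makes sense together with the matmul that follows it can be checked with a small pure-Python simulation. The setup uses two ranks as list entries and made-up numbers, and the helpers are hypothetical. The same f backward (an allreduce) is applied after two different ops:
  • after a column-parallel matmul, it gives the right gradient;
  • after a replicated elementwise op, it overcounts by the number of ranks.

```python
# Simulation only: ranks are list entries; helper names are hypothetical.


def allreduce_sum(per_rank):
    # Simulated allreduce over per-rank vectors.
    total = [sum(vals) for vals in zip(*per_rank)]
    return [list(total) for _ in per_rank]


def f_backward(grad_per_rank):
    # Megatron-style f: identity in forward, allreduce in backward.
    return allreduce_sum(grad_per_rank)


def grad_of_vec_mat(w, grad_y):
    # Gradient of y = x @ w with respect to x: grad_y @ w^T.
    return [sum(w[i][j] * grad_y[j] for j in range(len(grad_y))) for i in range(len(w))]


# Case 1: f is followed by a column-parallel matmul (weight columns split across 2 ranks).
W_shards = [[[1, 2], [5, 6]], [[3, 4], [7, 8]]]  # column shards of [[1,2,3,4],[5,6,7,8]]
grad_y_shards = [[1, 1], [1, 1]]                 # column-sharded grad of y
local = [grad_of_vec_mat(w, g) for w, g in zip(W_shards, grad_y_shards)]
print(f_backward(local))  # [[10, 26], [10, 26]]: partial sums reduced correctly

# Case 2: f is followed by a replicated elementwise op y = 3 * x that every rank
# computes in full, so each rank's local grad of x is already complete.
grad_y = [1, 1]
local = [[3 * g for g in grad_y] for _ in range(2)]
print(f_backward(local))  # [[6, 6], [6, 6]]: overcounted; the correct grad is [3, 3]
```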

This is why I think DTensor and collectives are two clearly separate levels of abstraction, so I still don’t think “redistribute but then remove the DTensor wrapper” is the right thing to do, at least not with DTensor concepts. :slight_smile: