Better understanding why AOTAutograd decomposes the `rms_norm` backward for Meta tensors, but emits `_fused_rms_norm_backward` for CUDA

I posted a series of questions on the forum today. I’ve listed them here in the hope that putting them together will shed more light on what I do and don’t know.


Consider the following code sample, which traces rms_norm and its backward pass:

import torch
import torch.nn as nn
from torch._dynamo.backends.common import aot_autograd
from torch.fx.graph_module import GraphModule

def custom_backend(gm: GraphModule, example_inputs):
    """Print the FX graph handed over by AOTAutograd, then run it unchanged.

    Passed as both ``fw_compiler`` and ``bw_compiler`` in ``trace_rmsnorm``,
    so the forward and the backward graphs are both printed.

    Args:
        gm: The graph module produced by AOTAutograd.
        example_inputs: Example inputs for ``gm`` (unused here).

    Returns:
        ``gm.forward``, i.e. the graph is executed as-is with no further
        compilation.
    """
    gm.print_readable()
    return gm.forward


def trace_rmsnorm(device='meta'):
    """Compile ``nn.RMSNorm`` on *device* and run one forward and one backward.

    The forward/backward FX graphs are printed by ``custom_backend`` as a
    side effect.

    Args:
        device: Device for the input and the module parameters, e.g.
            ``'meta'`` or ``'cuda:0'``.

    Returns:
        The output of the compiled forward pass.
    """
    x = torch.randn(4, 8, device=device, requires_grad=True)
    rmsnorm = nn.RMSNorm(8, device=device)

    def example_rmsnorm(tensor):
        return rmsnorm(tensor)

    compiled_fn = torch.compile(
        example_rmsnorm,
        # bw_compiler is spelled out so it is explicit that custom_backend
        # is also the callback that prints the backward graph.
        backend=aot_autograd(
            fw_compiler=custom_backend,
            bw_compiler=custom_backend,
        ),
        dynamic=True,
        fullgraph=True,
    )

    # Call the compiled function once and reuse the result for printing.
    out = compiled_fn(x)
    # Run backward explicitly so the backward graph is compiled (and printed)
    # whether AOTAutograd compiles it eagerly at forward time or lazily on
    # the first backward call.
    out.sum().backward()
    return out


if __name__ == "__main__":
    # 'meta' produces the primitive-op backward graph; 'cuda:0' produces the
    # backward graph that calls aten._fused_rms_norm_backward.
    device = 'meta'
    print(trace_rmsnorm(device))

The following are the forward and backward FX graphs I get from the above code (with device = 'meta').

FXGraph Code
class GraphModule(torch.nn.Module):
    # NOTE(review): AOTAutograd forward graph captured with device='meta'.
    # Besides the inputs and rsqrt, it also returns mul_8 (primals_3 * rsqrt)
    # as a tensor saved for the backward; the device='cuda:0' forward graph
    # does not return that product.
    def forward(self, primals_1: "f32[8]", primals_2: "Sym(s79)", primals_3: "f32[s79, 8]"):
         # File: /workspace/native_pytorch/lib/python3.12/site-packages/torch/nn/functional.py:2920 in rms_norm, code: return torch.rms_norm(input, normalized_shape, weight, eps)
        pow_1: "f32[s79, 8]" = torch.ops.aten.pow.Tensor_Scalar(primals_3, 2)
        mean: "f32[s79, 1]" = torch.ops.aten.mean.dim(pow_1, [1], True);  pow_1 = None
        # the added constant appears to be eps (float32 machine epsilon)
        add_9: "f32[s79, 1]" = torch.ops.aten.add.Scalar(mean, 1.1920928955078125e-07);  mean = None
        rsqrt: "f32[s79, 1]" = torch.ops.aten.rsqrt.default(add_9);  add_9 = None
        mul_8: "f32[s79, 8]" = torch.ops.aten.mul.Tensor(primals_3, rsqrt)
        mul_11: "f32[s79, 8]" = torch.ops.aten.mul.Tensor(mul_8, primals_1)
        return (mul_11, primals_1, primals_3, rsqrt, mul_8, primals_2)
        
class GraphModule(torch.nn.Module):
    # NOTE(review): AOTAutograd backward graph captured with device='meta'.
    # There is no aten._fused_rms_norm_backward call here; the gradient is
    # assembled from primitive aten ops (mul / sum / pow / expand / div / add).
    # Grads are returned in forward-input order: primals_1 (f32[8]),
    # primals_2 (symint -> None), primals_3 (f32[s79, 8]).
    def forward(self, primals_2: "Sym(s79)", primals_1: "f32[8]", primals_3: "f32[s79, 8]", rsqrt: "f32[s79, 1]", mul_8: "f32[s79, 8]", tangents_1: "f32[s79, 8]"):
         # File: /workspace/native_pytorch/lib/python3.12/site-packages/torch/nn/functional.py:2920 in rms_norm, code: return torch.rms_norm(input, normalized_shape, weight, eps)
        detach: "f32[s79, 1]" = torch.ops.aten.detach.default(rsqrt)
        detach_1: "f32[s79, 1]" = torch.ops.aten.detach.default(detach);  detach = None
        # mul_17 -> sum_1 -> view: grad w.r.t. primals_1, i.e. tangents_1 * mul_8
        # summed over dim 0 (mul_8 is the saved primals_3 * rsqrt).
        mul_17: "f32[s79, 8]" = torch.ops.aten.mul.Tensor(tangents_1, mul_8);  mul_8 = None
        mul_18: "f32[s79, 8]" = torch.ops.aten.mul.Tensor(tangents_1, primals_1);  tangents_1 = primals_1 = None
        sum_1: "f32[1, 8]" = torch.ops.aten.sum.dim_IntList(mul_17, [0], True);  mul_17 = None
        view: "f32[8]" = torch.ops.aten.view.default(sum_1, [8]);  sum_1 = None
        mul_19: "f32[s79, 8]" = torch.ops.aten.mul.Tensor(mul_18, primals_3)
        # direct term of the grad w.r.t. primals_3: tangents_1 * primals_1 * rsqrt
        mul_20: "f32[s79, 8]" = torch.ops.aten.mul.Tensor(mul_18, rsqrt);  mul_18 = rsqrt = None
        sum_2: "f32[s79, 1]" = torch.ops.aten.sum.dim_IntList(mul_19, [1], True);  mul_19 = None
        detach_2: "f32[s79, 1]" = torch.ops.aten.detach.default(detach_1);  detach_1 = None
        detach_3: "f32[s79, 1]" = torch.ops.aten.detach.default(detach_2);  detach_2 = None
        # looks like rsqrt's backward: -0.5 * rsqrt**3 (detach_3 is a detached rsqrt)
        mul_21: "f32[s79, 1]" = torch.ops.aten.mul.Scalar(sum_2, -0.5);  sum_2 = None
        pow_2: "f32[s79, 1]" = torch.ops.aten.pow.Tensor_Scalar(detach_3, 3);  detach_3 = None
        mul_22: "f32[s79, 1]" = torch.ops.aten.mul.Tensor(mul_21, pow_2);  mul_21 = pow_2 = None
        # looks like mean's backward: broadcast back to [s79, 8] and divide by 8
        expand: "f32[s79, 8]" = torch.ops.aten.expand.default(mul_22, [primals_2, 8]);  mul_22 = primals_2 = None
        div: "f32[s79, 8]" = torch.ops.aten.div.Scalar(expand, 8);  expand = None
        # looks like pow(x, 2)'s backward: 2.0 * x**1.0
        pow_3: "f32[s79, 8]" = torch.ops.aten.pow.Tensor_Scalar(primals_3, 1.0);  primals_3 = None
        mul_23: "f32[s79, 8]" = torch.ops.aten.mul.Scalar(pow_3, 2.0);  pow_3 = None
        mul_24: "f32[s79, 8]" = torch.ops.aten.mul.Tensor(div, mul_23);  div = mul_23 = None
        
         # File: /workspace/native_pytorch/lib/python3.12/site-packages/torch/nn/functional.py:2920 in rms_norm, code: return torch.rms_norm(input, normalized_shape, weight, eps)
        # grad w.r.t. primals_3: direct term plus the term flowing back through rsqrt/mean/pow
        add_25: "f32[s79, 8]" = torch.ops.aten.add.Tensor(mul_20, mul_24);  mul_20 = mul_24 = None
        return (view, None, add_25)

Below are the forward and backward FX graphs I get with device = 'cuda:0'.

FXGraph Code
class GraphModule(torch.nn.Module):
    # NOTE(review): AOTAutograd forward graph captured with device='cuda:0'.
    # Same ops as the device='meta' forward graph, but the product mul_6
    # (primals_3 * rsqrt) is freed right after use and not returned; only the
    # inputs and rsqrt are saved for the backward.
    def forward(self, primals_1: "f32[8]", primals_2: "Sym(s79)", primals_3: "f32[s79, 8]"):
         # File: /workspace/native_pytorch/lib/python3.12/site-packages/torch/nn/functional.py:2920 in rms_norm, code: return torch.rms_norm(input, normalized_shape, weight, eps)
        pow_1: "f32[s79, 8]" = torch.ops.aten.pow.Tensor_Scalar(primals_3, 2)
        mean: "f32[s79, 1]" = torch.ops.aten.mean.dim(pow_1, [1], True);  pow_1 = None
        add_9: "f32[s79, 1]" = torch.ops.aten.add.Scalar(mean, 1.1920928955078125e-07);  mean = None
        rsqrt: "f32[s79, 1]" = torch.ops.aten.rsqrt.default(add_9);  add_9 = None
        mul_6: "f32[s79, 8]" = torch.ops.aten.mul.Tensor(primals_3, rsqrt)
        mul_9: "f32[s79, 8]" = torch.ops.aten.mul.Tensor(mul_6, primals_1);  mul_6 = None
        return (mul_9, primals_1, primals_3, rsqrt, primals_2)
        
class GraphModule(torch.nn.Module):
    # NOTE(review): AOTAutograd backward graph captured with device='cuda:0'.
    # The whole gradient is a single aten._fused_rms_norm_backward call instead
    # of the primitive-op decomposition seen in the device='meta' backward graph.
    def forward(self, primals_2: "Sym(s79)", primals_1: "f32[8]", primals_3: "f32[s79, 8]", rsqrt: "f32[s79, 1]", tangents_1: "f32[s79, 8]"):
         # File: /workspace/native_pytorch/lib/python3.12/site-packages/torch/nn/functional.py:2920 in rms_norm, code: return torch.rms_norm(input, normalized_shape, weight, eps)
        # detach -> detach_3 is a chain of detaches of the saved rsqrt; only
        # detach_3 is consumed below.
        detach: "f32[s79, 1]" = torch.ops.aten.detach.default(rsqrt);  rsqrt = None
        detach_1: "f32[s79, 1]" = torch.ops.aten.detach.default(detach);  detach = None
        detach_2: "f32[s79, 1]" = torch.ops.aten.detach.default(detach_1);  detach_1 = None
        detach_3: "f32[s79, 1]" = torch.ops.aten.detach.default(detach_2);  detach_2 = None
        # Arguments presumably are (grad_out, input, normalized_shape, rstd,
        # weight, output_mask) -- TODO confirm against the op schema.
        _fused_rms_norm_backward = torch.ops.aten._fused_rms_norm_backward.default(tangents_1, primals_3, [8], detach_3, primals_1, [True, True]);  tangents_1 = primals_3 = detach_3 = primals_1 = None
        # By shape: [0] is the f32[s79, 8] grad w.r.t. primals_3 and [1] is the
        # f32[8] grad w.r.t. primals_1.
        getitem: "f32[s79, 8]" = _fused_rms_norm_backward[0]
        getitem_1: "f32[8]" = _fused_rms_norm_backward[1];  _fused_rms_norm_backward = None
        # Returned in forward-input order: (primals_1, primals_2 -> None, primals_3).
        return (getitem_1, None, getitem)

I have also attached the logs I get when running with `TORCH_SHOW_DISPATCH_TRACE=1 TORCH_LOGS=+dynamo,bytecode,+aot,graph,graph_code,aot_graphs`:

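My original question was: “Why does AOTAutograd decompose the backward pass of rms_norm only for device=meta, when the forward pass is decomposed into the same ops on both meta and cuda:0?” (The one visible difference in the forward graphs is that the Meta graph also returns `mul_8` as a tensor saved for the backward, while the CUDA graph saves only `rsqrt`.)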

My intuitive answer is: “When tracing the forward pass, Autograd records rms_norm differently on its internal ‘tape’ (which it replays when tracing the backward pass), even though the forward FX graph produced by AOTAutograd records the same decomposition.”

At this point I am ruling out the involvement of the Dispatcher and the difference between MetaBit and CUDABit. I am also ruling out the case that this is similar to this earlier forum question (i.e. something is happening in Autograd, not in the Dispatcher). This reasoning comes from the fact that in the Dispatcher logs (from TORCH_SHOW_DISPATCH_TRACE=1), `_fused_rms_norm_backward` does not appear at all for Meta, while it does for CUDA.

Meta Log
 [call] op=[aten::mul.Tensor], key=[PythonDispatcher]
  [redispatchBoxed] op=[aten::mul.Tensor], key=[PythonDispatcher]
   [redispatch] op=[aten::mul.Tensor], key=[PythonDispatcher]
    [callBoxed] op=[aten::mul.Tensor], key=[PythonDispatcher]
     [callBoxed] op=[aten::mul.Tensor], key=[PythonDispatcher]
      [call] op=[aten::empty.memory_format], key=[PythonDispatcher]
       [redispatch] op=[aten::empty.memory_format], key=[Meta]
      [call] op=[aten::detach], key=[Meta]
      [call] op=[aten::empty_strided], key=[PythonDispatcher]
       [redispatch] op=[aten::empty_strided], key=[Meta]
      [call] op=[aten::detach], key=[Meta]
     [callBoxed] op=[aten::detach], key=[Meta]
     [call] op=[aten::detach], key=[Meta]
 [call] op=[aten::mul.Tensor], key=[PythonDispatcher]
  [redispatchBoxed] op=[aten::mul.Tensor], key=[PythonDispatcher]
   [redispatch] op=[aten::mul.Tensor], key=[PythonDispatcher]
    [callBoxed] op=[aten::mul.Tensor], key=[PythonDispatcher]
     [callBoxed] op=[aten::mul.Tensor], key=[PythonDispatcher]
      [call] op=[aten::empty_strided], key=[PythonDispatcher]
       [redispatch] op=[aten::empty_strided], key=[Meta]
      [call] op=[aten::detach], key=[Meta]
     [callBoxed] op=[aten::detach], key=[Meta]
     [call] op=[aten::detach], key=[Meta]
 [call] op=[aten::sum.dim_IntList], key=[PythonDispatcher]
  [redispatchBoxed] op=[aten::sum.dim_IntList], key=[PythonDispatcher]
   [redispatch] op=[aten::sum.dim_IntList], key=[PythonDispatcher]
CUDA Log
 [call] op=[aten::_fused_rms_norm_backward], key=[PythonDispatcher]
  [redispatchBoxed] op=[aten::_fused_rms_norm_backward], key=[PythonDispatcher]
   [redispatchBoxed] op=[aten::_fused_rms_norm_backward], key=[PythonDispatcher]
    [callBoxed] op=[aten::_fused_rms_norm_backward], key=[PythonDispatcher]
     [callBoxed] op=[aten::_fused_rms_norm_backward], key=[PythonDispatcher]
      [call] op=[aten::to.dtype], key=[PythonDispatcher]
      [call] op=[aten::to.dtype], key=[PythonDispatcher]
      [call] op=[aten::to.dtype], key=[PythonDispatcher]
      [call] op=[aten::mul.Tensor], key=[PythonDispatcher]
       [call] op=[aten::empty_strided], key=[PythonDispatcher]
        [redispatch] op=[aten::empty_strided], key=[Meta]
       [call] op=[aten::detach], key=[Meta]
      [call] op=[aten::mul.Tensor], key=[PythonDispatcher]
       [call] op=[aten::empty_strided], key=[PythonDispatcher]
        [redispatch] op=[aten::empty_strided], key=[Meta]
       [call] op=[aten::detach], key=[Meta]
      [call] op=[aten::mul.Tensor], key=[PythonDispatcher]
       [call] op=[aten::empty.memory_format], key=[PythonDispatcher]

Now here are my true questions:

  1. Is the above intuitive answer correct?

  2. Why would rms_norm be recorded differently on the autograd tape when the forward decomposition (as captured by the forward FX graph) is the same?

  3. Are there any code pointers where I can debug, print, or set breakpoints to observe how this is recorded differently?

I am using a local PyTorch build from commit 9ad7dd5, dated Jul 28 (a more recent build broke due to what looks like a cuDNN version mismatch, although I’m sure I did something wrong… fingers crossed that nothing fundamental has changed since then):