I posted a series of questions on the forum today. I've listed them together here in the hope that seeing them side by side better shows what I do and don't know.
- In Dynamo+AOTAutograd, why run FakeTensor through the code multiple times?
- In resolving DispatchKeySet, is PythonDispatch called first or last?
- Better understanding why AOTAutograd decomposes fused_rms_norm_backward for CUDA, but not for Meta tensors (this post)
Refer to the following code sample, tracing rms_norm and its backward pass:
import torch
import torch.nn as nn
from torch._dynamo.backends.common import aot_autograd
from torch.fx.graph_module import GraphModule

def custom_backend(gm: GraphModule, example_inputs):
    # Print the FX graph AOTAutograd hands us, then run it unchanged.
    gm.print_readable()
    return gm.forward

device = 'meta'
x = torch.randn(4, 8, device=device, requires_grad=True)
rmsnorm = nn.RMSNorm(8, device=device)

def example_rmsnorm(tensor):
    return rmsnorm(tensor)

compiled_fn = torch.compile(
    example_rmsnorm,
    backend=aot_autograd(fw_compiler=custom_backend),
    dynamic=True,
    fullgraph=True,
)
out = compiled_fn(x)
print(compiled_fn(x))
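(A side note on the snippet above: I only pass fw_compiler, but as far as I can tell AOTAutograd reuses it as the bw_compiler when none is given, which is why print_readable shows both the forward and the backward graph below. Spelling that out explicitly would look like this:)

# Equivalent call with the backward compiler made explicit; my understanding is
# that bw_compiler simply defaults to fw_compiler when it is omitted.
compiled_fn = torch.compile(
    example_rmsnorm,
    backend=aot_autograd(fw_compiler=custom_backend, bw_compiler=custom_backend),
    dynamic=True,
    fullgraph=True,
)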
The following is the FXGraph I get from the above code.
FXGraph Code
class GraphModule(torch.nn.Module):
    def forward(self, primals_1: "f32[8]", primals_2: "Sym(s79)", primals_3: "f32[s79, 8]"):
        # File: /workspace/native_pytorch/lib/python3.12/site-packages/torch/nn/functional.py:2920 in rms_norm, code: return torch.rms_norm(input, normalized_shape, weight, eps)
        pow_1: "f32[s79, 8]" = torch.ops.aten.pow.Tensor_Scalar(primals_3, 2)
        mean: "f32[s79, 1]" = torch.ops.aten.mean.dim(pow_1, [1], True); pow_1 = None
        add_9: "f32[s79, 1]" = torch.ops.aten.add.Scalar(mean, 1.1920928955078125e-07); mean = None
        rsqrt: "f32[s79, 1]" = torch.ops.aten.rsqrt.default(add_9); add_9 = None
        mul_8: "f32[s79, 8]" = torch.ops.aten.mul.Tensor(primals_3, rsqrt)
        mul_11: "f32[s79, 8]" = torch.ops.aten.mul.Tensor(mul_8, primals_1)
        return (mul_11, primals_1, primals_3, rsqrt, mul_8, primals_2)

class GraphModule(torch.nn.Module):
    def forward(self, primals_2: "Sym(s79)", primals_1: "f32[8]", primals_3: "f32[s79, 8]", rsqrt: "f32[s79, 1]", mul_8: "f32[s79, 8]", tangents_1: "f32[s79, 8]"):
        # File: /workspace/native_pytorch/lib/python3.12/site-packages/torch/nn/functional.py:2920 in rms_norm, code: return torch.rms_norm(input, normalized_shape, weight, eps)
        detach: "f32[s79, 1]" = torch.ops.aten.detach.default(rsqrt)
        detach_1: "f32[s79, 1]" = torch.ops.aten.detach.default(detach); detach = None
        mul_17: "f32[s79, 8]" = torch.ops.aten.mul.Tensor(tangents_1, mul_8); mul_8 = None
        mul_18: "f32[s79, 8]" = torch.ops.aten.mul.Tensor(tangents_1, primals_1); tangents_1 = primals_1 = None
        sum_1: "f32[1, 8]" = torch.ops.aten.sum.dim_IntList(mul_17, [0], True); mul_17 = None
        view: "f32[8]" = torch.ops.aten.view.default(sum_1, [8]); sum_1 = None
        mul_19: "f32[s79, 8]" = torch.ops.aten.mul.Tensor(mul_18, primals_3)
        mul_20: "f32[s79, 8]" = torch.ops.aten.mul.Tensor(mul_18, rsqrt); mul_18 = rsqrt = None
        sum_2: "f32[s79, 1]" = torch.ops.aten.sum.dim_IntList(mul_19, [1], True); mul_19 = None
        detach_2: "f32[s79, 1]" = torch.ops.aten.detach.default(detach_1); detach_1 = None
        detach_3: "f32[s79, 1]" = torch.ops.aten.detach.default(detach_2); detach_2 = None
        mul_21: "f32[s79, 1]" = torch.ops.aten.mul.Scalar(sum_2, -0.5); sum_2 = None
        pow_2: "f32[s79, 1]" = torch.ops.aten.pow.Tensor_Scalar(detach_3, 3); detach_3 = None
        mul_22: "f32[s79, 1]" = torch.ops.aten.mul.Tensor(mul_21, pow_2); mul_21 = pow_2 = None
        expand: "f32[s79, 8]" = torch.ops.aten.expand.default(mul_22, [primals_2, 8]); mul_22 = primals_2 = None
        div: "f32[s79, 8]" = torch.ops.aten.div.Scalar(expand, 8); expand = None
        pow_3: "f32[s79, 8]" = torch.ops.aten.pow.Tensor_Scalar(primals_3, 1.0); primals_3 = None
        mul_23: "f32[s79, 8]" = torch.ops.aten.mul.Scalar(pow_3, 2.0); pow_3 = None
        mul_24: "f32[s79, 8]" = torch.ops.aten.mul.Tensor(div, mul_23); div = mul_23 = None
        # File: /workspace/native_pytorch/lib/python3.12/site-packages/torch/nn/functional.py:2920 in rms_norm, code: return torch.rms_norm(input, normalized_shape, weight, eps)
        add_25: "f32[s79, 8]" = torch.ops.aten.add.Tensor(mul_20, mul_24); mul_20 = mul_24 = None
        return (view, None, add_25)
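As a sanity check on what this Meta backward graph is computing, here is a small eager-mode sketch (entirely mine, run on CPU) that hand-rolls the decomposed RMSNorm backward following the ops above and compares it against torch.autograd:

# Sanity-check sketch (mine): the decomposed backward above should agree with
# what torch.autograd produces for the same RMSNorm formula on CPU.
import torch

N = 8
eps = 1.1920928955078125e-07
x = torch.randn(4, N, dtype=torch.double, requires_grad=True)
w = torch.randn(N, dtype=torch.double, requires_grad=True)
g = torch.randn(4, N, dtype=torch.double)  # incoming tangents

r = torch.rsqrt(x.pow(2).mean(dim=1, keepdim=True) + eps)  # pow_1 -> mean -> add_9 -> rsqrt
(x * r * w).backward(g)

# Hand-rolled gradients following the Meta backward graph
grad_w = (g * x * r).sum(dim=0)                            # mul_17 -> sum_1 -> view
grad_x = g * w * r + (-0.5 * r.pow(3)) * (g * w * x).sum(dim=1, keepdim=True) * (2.0 * x / N)

print(torch.allclose(w.grad, grad_w), torch.allclose(x.grad, grad_x))  # expect True True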
Below is the FXGraph I get with device = 'cuda:0'.
FXGraph Code
class GraphModule(torch.nn.Module):
    def forward(self, primals_1: "f32[8]", primals_2: "Sym(s79)", primals_3: "f32[s79, 8]"):
        # File: /workspace/native_pytorch/lib/python3.12/site-packages/torch/nn/functional.py:2920 in rms_norm, code: return torch.rms_norm(input, normalized_shape, weight, eps)
        pow_1: "f32[s79, 8]" = torch.ops.aten.pow.Tensor_Scalar(primals_3, 2)
        mean: "f32[s79, 1]" = torch.ops.aten.mean.dim(pow_1, [1], True); pow_1 = None
        add_9: "f32[s79, 1]" = torch.ops.aten.add.Scalar(mean, 1.1920928955078125e-07); mean = None
        rsqrt: "f32[s79, 1]" = torch.ops.aten.rsqrt.default(add_9); add_9 = None
        mul_6: "f32[s79, 8]" = torch.ops.aten.mul.Tensor(primals_3, rsqrt)
        mul_9: "f32[s79, 8]" = torch.ops.aten.mul.Tensor(mul_6, primals_1); mul_6 = None
        return (mul_9, primals_1, primals_3, rsqrt, primals_2)

class GraphModule(torch.nn.Module):
    def forward(self, primals_2: "Sym(s79)", primals_1: "f32[8]", primals_3: "f32[s79, 8]", rsqrt: "f32[s79, 1]", tangents_1: "f32[s79, 8]"):
        # File: /workspace/native_pytorch/lib/python3.12/site-packages/torch/nn/functional.py:2920 in rms_norm, code: return torch.rms_norm(input, normalized_shape, weight, eps)
        detach: "f32[s79, 1]" = torch.ops.aten.detach.default(rsqrt); rsqrt = None
        detach_1: "f32[s79, 1]" = torch.ops.aten.detach.default(detach); detach = None
        detach_2: "f32[s79, 1]" = torch.ops.aten.detach.default(detach_1); detach_1 = None
        detach_3: "f32[s79, 1]" = torch.ops.aten.detach.default(detach_2); detach_2 = None
        _fused_rms_norm_backward = torch.ops.aten._fused_rms_norm_backward.default(tangents_1, primals_3, [8], detach_3, primals_1, [True, True]); tangents_1 = primals_3 = detach_3 = primals_1 = None
        getitem: "f32[s79, 8]" = _fused_rms_norm_backward[0]
        getitem_1: "f32[8]" = _fused_rms_norm_backward[1]; _fused_rms_norm_backward = None
        return (getitem_1, None, getitem)
I have also attached the logs I get when running with TORCH_SHOW_DISPATCH_TRACE=1 TORCH_LOGS=+dynamo,bytecode,+aot,graph,graph_code,aot_graphs.
My original question was: "Why does AOTAutograd decompose the backward pass of rms_norm only for device=Meta, when the forward pass is decomposed the same way on both Meta and cuda:0?"
My intuitive answer is: "When tracing the forward pass, Autograd records rms_norm differently on its internal 'tape' (which it replays when tracing the backward pass), even though Dynamo records the same decomposition in the forward-pass FX graph."
At this point I am ruling out the involvement of the Dispatcher and the difference between MetaBit and CUDABit. I am also ruling out the possibility that this is similar to this earlier forum question (i.e. something is happening in Autograd, not in the Dispatcher). This reasoning comes from the fact that in the Dispatcher logs (from TORCH_SHOW_DISPATCH_TRACE=1), fused_rms_norm_backward is not considered at all for Meta, while it is for CUDA.
Meta Log
[call] op=[aten::mul.Tensor], key=[PythonDispatcher]
[redispatchBoxed] op=[aten::mul.Tensor], key=[PythonDispatcher]
[redispatch] op=[aten::mul.Tensor], key=[PythonDispatcher]
[callBoxed] op=[aten::mul.Tensor], key=[PythonDispatcher]
[callBoxed] op=[aten::mul.Tensor], key=[PythonDispatcher]
[call] op=[aten::empty.memory_format], key=[PythonDispatcher]
[redispatch] op=[aten::empty.memory_format], key=[Meta]
[call] op=[aten::detach], key=[Meta]
[call] op=[aten::empty_strided], key=[PythonDispatcher]
[redispatch] op=[aten::empty_strided], key=[Meta]
[call] op=[aten::detach], key=[Meta]
[callBoxed] op=[aten::detach], key=[Meta]
[call] op=[aten::detach], key=[Meta]
[call] op=[aten::mul.Tensor], key=[PythonDispatcher]
[redispatchBoxed] op=[aten::mul.Tensor], key=[PythonDispatcher]
[redispatch] op=[aten::mul.Tensor], key=[PythonDispatcher]
[callBoxed] op=[aten::mul.Tensor], key=[PythonDispatcher]
[callBoxed] op=[aten::mul.Tensor], key=[PythonDispatcher]
[call] op=[aten::empty_strided], key=[PythonDispatcher]
[redispatch] op=[aten::empty_strided], key=[Meta]
[call] op=[aten::detach], key=[Meta]
[callBoxed] op=[aten::detach], key=[Meta]
[call] op=[aten::detach], key=[Meta]
[call] op=[aten::sum.dim_IntList], key=[PythonDispatcher]
[redispatchBoxed] op=[aten::sum.dim_IntList], key=[PythonDispatcher]
[redispatch] op=[aten::sum.dim_IntList], key=[PythonDispatcher]
CUDA Log
[call] op=[aten::_fused_rms_norm_backward], key=[PythonDispatcher]
[redispatchBoxed] op=[aten::_fused_rms_norm_backward], key=[PythonDispatcher]
[redispatchBoxed] op=[aten::_fused_rms_norm_backward], key=[PythonDispatcher]
[callBoxed] op=[aten::_fused_rms_norm_backward], key=[PythonDispatcher]
[callBoxed] op=[aten::_fused_rms_norm_backward], key=[PythonDispatcher]
[call] op=[aten::to.dtype], key=[PythonDispatcher]
[call] op=[aten::to.dtype], key=[PythonDispatcher]
[call] op=[aten::to.dtype], key=[PythonDispatcher]
[call] op=[aten::mul.Tensor], key=[PythonDispatcher]
[call] op=[aten::empty_strided], key=[PythonDispatcher]
[redispatch] op=[aten::empty_strided], key=[Meta]
[call] op=[aten::detach], key=[Meta]
[call] op=[aten::mul.Tensor], key=[PythonDispatcher]
[call] op=[aten::empty_strided], key=[PythonDispatcher]
[redispatch] op=[aten::empty_strided], key=[Meta]
[call] op=[aten::detach], key=[Meta]
[call] op=[aten::mul.Tensor], key=[PythonDispatcher]
[call] op=[aten::empty.memory_format], key=[PythonDispatcher]
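To double-check which backends actually have a kernel registered for the fused op, I also tried dumping its dispatch registrations. This is a rough sketch; I'm assuming the private _dispatch_* introspection hooks below are the right things to poke at:

# Rough sketch (mine): inspect what is registered for the fused backward op.
# Assumption: torch._C._dispatch_dump and _dispatch_has_kernel_for_dispatch_key
# are the right (private) introspection hooks for this.
import torch

print(torch._C._dispatch_dump("aten::_fused_rms_norm_backward"))

for key in ["CPU", "CUDA", "Meta", "CompositeExplicitAutograd", "CompositeImplicitAutograd"]:
    print(key, torch._C._dispatch_has_kernel_for_dispatch_key("aten::_fused_rms_norm_backward", key))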
Now here are my true questions:
- Is the above intuitive answer correct?
- Why would AOTAutograd record rms_norm differently in the tape, when the forward-pass decomposition (as captured by the forward FXGraph) is the same?
- Are there any code pointers where I can actually debug/print/breakpoint and observe how this is recorded differently?
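For the last question, one place I was planning to start is the Python decomposition tables, to see whether anything rms_norm-related is registered there at all. This is just a guess at where to look; the sketch assumes torch._decomp is the relevant registry:

# Sketch (mine): list rms_norm-related entries in the Python decomposition tables.
# Assumption: torch._decomp is the registry AOTAutograd draws decompositions from.
import torch
from torch._decomp import decomposition_table, core_aten_decompositions

for op in decomposition_table:
    if "rms_norm" in str(op):
        print("decomposition_table:", op)

for op in core_aten_decompositions():
    if "rms_norm" in str(op):
        print("core_aten_decompositions:", op)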
The PyTorch build is a local build at PyTorch commit 9ad7dd5, dated Jul 28 (a more recent version broke for me due to what looks like a cuDNN version mismatch, although I'm sure I did something wrong… fingers crossed that nothing fundamental has changed since then):