The feature, motivation and pitch
How can I get a runnable copy of the backward graph when running torch.compile?
For example:
@torch.compile
def f(x):
    out = x.sin() + x.cos()
    return out

x = torch.ones(2, device="cuda", requires_grad=True)
out = f(x)
If I run this with TORCH_LOGS=all, I can see that the backward graph is traced:
===== Backward graph 0 =====
<eval_with_key>.36 class GraphModule(torch.nn.Module):
    def forward(self, primals_1: "f32[2]", tangents_1: "f32[2]"):
        # File: /notebooks/clean-repos/triton-autodiff/test_compile.py:32 in f, code: tmp1 = x.sin() + x.cos()
        sin: "f32[2]" = torch.ops.aten.sin.default(primals_1)
        cos: "f32[2]" = torch.ops.aten.cos.default(primals_1); primals_1 = None
        neg: "f32[2]" = torch.ops.aten.neg.default(sin); sin = None
        mul: "f32[2]" = torch.ops.aten.mul.Tensor(tangents_1, neg); neg = None
        mul_1: "f32[2]" = torch.ops.aten.mul.Tensor(tangents_1, cos); tangents_1 = cos = None
        # File: /notebooks/clean-repos/triton-autodiff/test_compile.py:32 in f, code: tmp1 = x.sin() + x.cos()
        add_1: "f32[2]" = torch.ops.aten.add.Tensor(mul, mul_1); mul = mul_1 = None
        return [add_1]
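(As an aside, I believe I can also get this FX-level backward graph a bit more directly, either by narrowing the logs to TORCH_LOGS=aot_graphs or with a custom AOTAutograd backend along the lines of the sketch below, but that still only gives me the graph object, not the generated code. The sketch assumes the documented fw_compiler/bw_compiler hooks; print_graph is my own helper name.)

import torch
from torch._dynamo.backends.common import aot_autograd
from functorch.compile import make_boxed_func

# Sketch: print each graph AOTAutograd hands to the compilers, then just run it eagerly.
# The bw_compiler appears to be invoked lazily, on the first backward() call.
def print_graph(gm, example_inputs):
    gm.print_readable()
    return make_boxed_func(gm.forward)

my_backend = aot_autograd(fw_compiler=print_graph, bw_compiler=print_graph)

@torch.compile(backend=my_backend)
def f(x):
    return x.sin() + x.cos()

x = torch.ones(2, device="cuda", requires_grad=True)
f(x).sum().backward()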
However, if I look at the Inductor cache, I see only the forward kernel (output.py):
from ctypes import c_void_p, c_long
import torch
import math
import random
import os
import tempfile
from math import inf, nan
from torch._inductor.hooks import run_intermediate_hooks
from torch._inductor.utils import maybe_profile
from torch._inductor.codegen.memory_planning import _align as align
from torch import device, empty, empty_strided
from torch._inductor.codecache import AsyncCompile
from torch._inductor.select_algorithm import extern_kernels
from torch._inductor.codegen.multi_kernel import MultiKernelCall
aten = torch.ops.aten
inductor_ops = torch.ops.inductor
assert_size_stride = torch._C._dynamo.guards.assert_size_stride
alloc_from_pool = torch.ops.inductor._alloc_from_pool
reinterpret_tensor = torch.ops.inductor._reinterpret_tensor
async_compile = AsyncCompile()
# kernel path: torch_compile/cache/rv/crviokfbw7lr4fi4yf2kqhtyjostdohc32ulg7vlfbyprvjinxaa.py
# Source Nodes: [cos, sin, tmp1], Original ATen: [aten.add, aten.cos, aten.sin]
# cos => cos
# sin => sin
# tmp1 => add
triton_poi_fused_add_cos_sin_0 = async_compile.triton('triton_', '''
import triton
import triton.language as tl
from torch._inductor.ir import ReductionHint
from torch._inductor.ir import TileHint
from torch._inductor.triton_heuristics import AutotuneHint, pointwise
from torch._inductor.utils import instance_descriptor
from torch._inductor import triton_helpers
@pointwise(
    size_hints=[2],
    filename=__file__,
    triton_meta={'signature': {0: '*fp32', 1: '*fp32', 2: 'i32'}, 'device': 0, 'device_type': 'cuda', 'constants': {}, 'configs': [instance_descriptor(divisible_by_16=(0, 1), equal_to_1=(), ids_of_folded_args=(), divisible_by_8=())]},
    inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_add_cos_sin_0', 'mutated_arg_names': [], 'no_x_dim': False},
    min_elem_per_thread=0
)
@triton.jit
def triton_(in_ptr0, out_ptr0, xnumel, XBLOCK : tl.constexpr):
    xnumel = 2
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:]
    xmask = xindex < xnumel
    x0 = xindex
    tmp0 = tl.load(in_ptr0 + (x0), xmask)
    tmp1 = tl.sin(tmp0)
    tmp2 = tl.cos(tmp0)
    tmp3 = tmp1 + tmp2
    tl.store(out_ptr0 + (x0), tmp3, xmask)
''')
...
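(My guess is that the backward kernel simply hasn't been code-generated yet, because nothing in the snippet ever calls backward(). Assuming Inductor compiles the backward graph lazily, continuing the example above with something like this should force a second kernel into the cache:)

out = f(x)
out.sum().backward()  # if backward compilation is lazy, this should trigger codegen of the backward kernel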
How can I get the codegen'ed equivalents for the backward graph, and the torch.autograd.Function that calls these forward/backward kernels?
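To make the ask concrete, what I'm hoping to end up with is roughly a wrapper of this shape, where compiled_forward / compiled_backward are made-up stand-ins (here just eager code mirroring the traced graphs above) for the Inductor-generated Triton callables:

import torch

# Hypothetical sketch of the wrapper I'd like to extract. compiled_forward /
# compiled_backward are placeholders for the generated kernels; here they are
# plain eager code mirroring the traced forward and backward graphs above.
def compiled_forward(x):
    return torch.sin(x) + torch.cos(x)

def compiled_backward(x, grad_out):
    # same math as the traced backward graph: grad * (-sin(x)) + grad * cos(x)
    return grad_out * (-torch.sin(x)) + grad_out * torch.cos(x)

class CompiledSinCos(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x):
        ctx.save_for_backward(x)  # the backward graph consumes primals_1 and tangents_1
        return compiled_forward(x)

    @staticmethod
    def backward(ctx, grad_out):
        (x,) = ctx.saved_tensors
        return compiled_backward(x, grad_out)

x = torch.ones(2, requires_grad=True)
CompiledSinCos.apply(x).sum().backward()
print(x.grad)  # should equal cos(1) - sin(1)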