Mutating input arguments in compiled functions (compiling a fused Adam kernel)

As a test case for torch.compile, I am trying to use it to generate a fused Adam weight update with mixed precision support. My code looks like this:

@torch.compile
def simple_mixed_precision_adam(w16, g16, m, v, w32, lr, beta1, beta2, eps):
    """Apply one simplified (no bias correction) Adam step entirely in place.

    With ``g`` denoting ``g16`` upcast to float32, the update is::

        m   <- beta1 * m + (1 - beta1) * g
        v   <- beta2 * v + (1 - beta2) * g * g
        w32 <- w32 - (lr * m) / (sqrt(v) + eps)

    The new fp32 weights are written into ``w32`` and also into ``w16``
    (``copy_`` casts them to ``w16``'s dtype). ``m`` and ``v`` are
    overwritten with the new moment estimates. Nothing is returned; every
    result is delivered through the mutated arguments.
    """
    grad = g16.to(torch.float32)  # upcast the bf16 gradient once

    # Exponential moving averages of the gradient and squared gradient.
    first_moment = beta1 * m + (1 - beta1) * grad
    second_moment = beta2 * v + (1 - beta2) * (grad * grad)

    # The weight step is computed entirely in fp32.
    denom = torch.sqrt(second_moment) + eps
    new_master = w32 - (lr * first_moment) / denom

    # In-place write-back: master weights, then the low-precision weight
    # copy (cast on copy), then the two optimizer moments.
    w32.copy_(new_master)
    w16.copy_(new_master)
    m.copy_(first_moment)
    v.copy_(second_moment)

# --- Inputs -----------------------------------------------------------
# 100M parameters on CUDA: fp32 master weights, a bf16 copy of them, a
# bf16 gradient, and zero-initialised fp32 Adam moment buffers.
nparams = 100_000_000
w32 = torch.randn(nparams, dtype=torch.float32, device="cuda")
w16 = w32.to(torch.bfloat16)  # low-precision copy of the master weights
g16 = torch.randn(nparams, dtype=torch.bfloat16, device="cuda")
m, v = (torch.zeros_like(w32) for _ in range(2))

# --- Hyperparameters --------------------------------------------------
lr, beta1, beta2, eps = 0.001, 0.9, 0.99, 1e-5

# --- One optimizer step (updates w16, m, v and w32 in place) ----------
simple_mixed_precision_adam(w16, g16, m, v, w32, lr, beta1, beta2, eps)

What I observe is that the compiled execution ends up running three separate Triton kernels. If I stare at the generated Triton code long enough (output_code.py from TORCH_COMPILE_DEBUG=1), the three kernels appear to be:

  1. Compute w32_new, cast to bf16, store to w16, exit
  2. Compute mostly everything, only do the stores to w32, m, and a temp buffer (the v_new values, I think?)
  3. Compute v_new, store to v

Put another way, it runs three kernels that each redo most of the computation from the start and store only a subset of the outputs in the epilogue. One of them even stores to a temporary buffer that is then thrown away.

I found I can eliminate this behavior by instead returning the new values with just one copy_:

# Variant: replaces the four copy_ calls at the end of the function.
# still do the in-place store of new bf16 weights
w16.copy_(w32_new)
# But return everything else to the caller; m_new, v_new and w32_new become
# newly allocated output buffers instead of being written in place.
return m_new, v_new, w32_new

Now I get a single kernel that does four stores at the end. But this is not ideal, since it allocates three new output buffers (for m_new, v_new and w32_new), and those can be quite large! A hand-written fused kernel would update all four tensors in place.

So the question: what is going on here, and is there a way to fix it? How do I get TorchInductor to generate code that stores all four outputs back into the input tensors? Thanks.

Hmm, have you tried running this code on a nightly build? (Install instructions are on the “Start Locally” page at pytorch.org.)

When I run the above snippet with the debug environment variable set (`TORCH_COMPILE_DEBUG=1 python tmp.py`; this makes the generated Triton code easy to find), I see a single Triton kernel generated:

from ctypes import c_void_p, c_long
import torch
import math
import random
import os
import tempfile
from math import inf, nan
from torch._inductor.hooks import run_intermediate_hooks
from torch._inductor.utils import maybe_profile

from torch import device, empty, empty_strided
from torch._inductor.codecache import AsyncCompile
from torch._inductor.select_algorithm import extern_kernels

# Short aliases referenced by the generated wrapper code in this module.
aten = torch.ops.aten
inductor_ops = torch.ops.inductor
assert_size_stride = torch._C._dynamo.guards.assert_size_stride
reinterpret_tensor = torch.ops.inductor._reinterpret_tensor
# Used below (async_compile.triton(...)) to compile the Triton kernel sources.
async_compile = AsyncCompile()


# kernel path: /tmp/torchinductor_hirsheybar/q4/cq4kbd5xdt5xogjikzypsicadp5snd7vh3xtnrjltwl73mlnszav.py
# Source Nodes: [add_2, copy_, copy__1, copy__2, copy__3, g32, m_new, mul, mul_1, mul_2, mul_3, mul_4, mul_5, sqrt, truediv, v_new, w32_new], Original ATen: [aten._to_copy, aten.add, aten.copy, aten.div, aten.mul, aten.sqrt, aten.sub]
# add_2 => add_2
# copy_ => copy
# copy__1 => copy_1
# copy__2 => copy_2
# copy__3 => copy_3
# g32 => convert_element_type
# m_new => add
# mul => mul
# mul_1 => mul_1
# mul_2 => mul_2
# mul_3 => mul_3
# mul_4 => mul_4
# mul_5 => mul_5
# sqrt => sqrt
# truediv => div
# v_new => add_1
# w32_new => sub
# NOTE(review): this is the single fused pointwise kernel generated on the 2.2
# nightly. Reading the launch in call(), the pointers appear to be bound as:
#   loads:  in_ptr0 = w32, in_ptr1 = m, in_ptr2 = g16, in_ptr3 = v
#   stores: out_ptr4 = w16, out_ptr5 = m, out_ptr6 = v, out_ptr7 = w32
# This assumes arg1_1..arg5_1 follow the Python signature order
# (w16, g16, m, v, w32). The dtypes in benchmark_compiled_module agree, but
# TODO confirm.
# The outputs alias the inputs, so 'mutated_arg_names' lists both the in_ptr
# and out_ptr names (g16 / in_ptr2 is read only).
# In the kernel body: tmp8 = m_new, tmp17 = v_new, tmp22 = w32_new.
# tmp23 is stored through out_ptr4, whose signature entry 4 is '*bf16'.
triton_poi_fused__to_copy_add_copy_div_mul_sqrt_sub_0 = async_compile.triton('triton_', '''
import triton
import triton.language as tl
from torch._inductor.ir import ReductionHint
from torch._inductor.ir import TileHint
from torch._inductor.triton_heuristics import AutotuneHint, pointwise
from torch._inductor.utils import instance_descriptor
from torch._inductor import triton_helpers

@pointwise(
    size_hints=[134217728],
    filename=__file__,
    meta={'signature': {0: '*fp32', 1: '*fp32', 2: '*bf16', 3: '*fp32', 4: '*bf16', 5: '*fp32', 6: '*fp32', 7: '*fp32', 8: 'i32'}, 'device': 0, 'device_type': 'cuda', 'constants': {}, 'mutated_arg_names': ['in_ptr0', 'in_ptr1', 'in_ptr3', 'out_ptr4', 'out_ptr5', 'out_ptr6', 'out_ptr7'], 'autotune_hints': set(), 'kernel_name': 'triton_poi_fused__to_copy_add_copy_div_mul_sqrt_sub_0', 'configs': [instance_descriptor(divisible_by_16=(0, 1, 2, 3, 4, 5, 6, 7), equal_to_1=(), ids_of_folded_args=(), divisible_by_8=())]},
    min_elem_per_thread=0
)
@triton.jit
def triton_(in_ptr0, in_ptr1, in_ptr2, in_ptr3, out_ptr4, out_ptr5, out_ptr6, out_ptr7, xnumel, XBLOCK : tl.constexpr):
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:]
    xmask = xindex < xnumel
    x0 = xindex
    tmp0 = tl.load(in_ptr0 + (x0), xmask)
    tmp1 = tl.load(in_ptr1 + (x0), xmask)
    tmp4 = tl.load(in_ptr2 + (x0), xmask).to(tl.float32)
    tmp11 = tl.load(in_ptr3 + (x0), xmask)
    tmp2 = 0.9
    tmp3 = tmp1 * tmp2
    tmp5 = tmp4.to(tl.float32)
    tmp6 = 0.09999999999999998
    tmp7 = tmp5 * tmp6
    tmp8 = tmp3 + tmp7
    tmp9 = 0.001
    tmp10 = tmp8 * tmp9
    tmp12 = 0.99
    tmp13 = tmp11 * tmp12
    tmp14 = tmp5 * tmp5
    tmp15 = 0.010000000000000009
    tmp16 = tmp14 * tmp15
    tmp17 = tmp13 + tmp16
    tmp18 = tl.sqrt(tmp17)
    tmp19 = 1e-05
    tmp20 = tmp18 + tmp19
    tmp21 = tmp10 / tmp20
    tmp22 = tmp0 - tmp21
    tmp23 = tmp22.to(tl.float32)
    tl.store(out_ptr4 + (x0), tmp23, xmask)
    tl.store(out_ptr5 + (x0), tmp8, xmask)
    tl.store(out_ptr6 + (x0), tmp17, xmask)
    tl.store(out_ptr7 + (x0), tmp22, xmask)
''')

import triton
import triton.language as tl
from torch._inductor.triton_heuristics import grid, start_graph, end_graph
from torch._C import _cuda_getCurrentRawStream as get_cuda_stream


# Presumably waits for the kernels submitted via async_compile.triton(...) and
# binds the results into this module's globals; the handle is then dropped.
async_compile.wait(globals())
del async_compile

def call(args):
    # Entry point of the compiled module: the whole graph is one kernel launch.
    # `args` is [s0, <5 tensors>]. s0 is the symbolic element count; the tensors
    # are presumably w16, g16, m, v, w32 in signature order (TODO confirm —
    # the dtypes used in benchmark_compiled_module agree).
    arg0_1, arg1_1, arg2_1, arg3_1, arg4_1, arg5_1 = args
    # Empty the caller's list so only this frame references the inputs
    # (presumably so they can be released as soon as they are deleted below).
    args.clear()
    s0 = arg0_1
    # Check that every tensor input has size (s0,) and stride (1,).
    assert_size_stride(arg1_1, (s0, ), (1, ))
    assert_size_stride(arg2_1, (s0, ), (1, ))
    assert_size_stride(arg3_1, (s0, ), (1, ))
    assert_size_stride(arg4_1, (s0, ), (1, ))
    assert_size_stride(arg5_1, (s0, ), (1, ))
    with torch.cuda._DeviceGuard(0):
        torch.cuda.set_device(0) # no-op to ensure context
        # Source Nodes: [add_2, copy_, copy__1, copy__2, copy__3, g32, m_new, mul, mul_1, mul_2, mul_3, mul_4, mul_5, sqrt, truediv, v_new, w32_new], Original ATen: [aten._to_copy, aten.add, aten.copy, aten.div, aten.mul, aten.sqrt, aten.sub]
        stream0 = get_cuda_stream(0)
        # Single fused launch. arg3_1, arg4_1 and arg5_1 are passed both as
        # inputs and as outputs, so they are updated in place. arg1_1 (bf16)
        # is output-only and arg2_1 is input-only.
        triton_poi_fused__to_copy_add_copy_div_mul_sqrt_sub_0.run(arg5_1, arg3_1, arg2_1, arg4_1, arg1_1, arg3_1, arg4_1, arg5_1, s0, grid=grid(s0), stream=stream0)
        del arg1_1
        del arg2_1
        del arg3_1
        del arg4_1
        del arg5_1
        # No graph outputs: the source function only mutates its arguments.
        return ()


def benchmark_compiled_module(times=10, repeat=10):
    # Standalone benchmark harness: builds random CUDA inputs with the sizes,
    # strides and dtypes that call() expects, then hands call() to inductor's
    # print_performance helper with the given times/repeat counts.
    from torch._dynamo.testing import rand_strided
    from torch._inductor.utils import print_performance
    arg0_1 = 100000000  # concrete value for the symbolic size s0
    arg1_1 = rand_strided((100000000, ), (1, ), device='cuda:0', dtype=torch.bfloat16)
    arg2_1 = rand_strided((100000000, ), (1, ), device='cuda:0', dtype=torch.bfloat16)
    arg3_1 = rand_strided((100000000, ), (1, ), device='cuda:0', dtype=torch.float32)
    arg4_1 = rand_strided((100000000, ), (1, ), device='cuda:0', dtype=torch.float32)
    arg5_1 = rand_strided((100000000, ), (1, ), device='cuda:0', dtype=torch.float32)
    # A fresh list is built on each invocation because call() clears the list it receives.
    return print_performance(lambda: call([arg0_1, arg1_1, arg2_1, arg3_1, arg4_1, arg5_1]), times=times, repeat=repeat)


if __name__ == "__main__":
    # Running this file directly passes benchmark_compiled_module to
    # inductor's compiled_module_main helper.
    from torch._inductor.wrapper_benchmark import compiled_module_main
    compiled_module_main('None', benchmark_compiled_module)

Thanks – I can confirm that the behavior has changed in the latest nightly. The results above were from the 2.1 release. I just pulled 2.2.0.dev20231028+cu121 and see the same thing you do – there’s a single generated kernel with all four stores in the epilogue.

This suggests that something changed in TorchInductor between 2.1 and 2.2 that improved kernel generation for this case. (The captured graph doesn’t look like it has changed, so it’s definitely something in code generation.)

I am curious what has changed and whether there’s a way to trigger this behavior in 2.1, but at the very least I’m glad to see things improving!