torch.compile C++ backend not fusing OuterLoopFusedSchedulerNode properly

I have the following simple test case for which I see an OuterLoopFusedSchedulerNode generated:

import torch


def test_group_norm():
    from torch._inductor import config

    # Force multiple scheduler nodes to be created so that they can be fused
    config.realize_opcount_threshold = 1
    config.inplace_buffers = False

    class M(torch.nn.Module):
        def __init__(self) -> None:
            super().__init__()
            self.group_norm = torch.nn.GroupNorm(2, 16)

        def forward(self, x):
            return self.group_norm(x)

    mod = M().eval()
    x = torch.randn(2, 16, 32, 32)
    with torch.no_grad():
        expected = mod(x)
        compiled_m = torch.compile(mod, backend="inductor", fullgraph=True)
        actual = compiled_m(x)
        assert torch.allclose(
            expected.cpu(), actual.cpu(), rtol=0.000001, atol=0.001
        ), "Test failed: Output does not match expected output"

However, in the generated code we can see that loops like the following should have been fused into a single outer loop:

for(long x0=static_cast<long>(0L); x0<static_cast<long>(4L); x0+=static_cast<long>(1L))

Why does this not happen? Is this intentional, or is it a bug? The full generated kernel is below:

extern "C" void kernel(const float* in_ptr0,
                       const float* in_ptr1,
                       const float* in_ptr2,
                       float* out_ptr0,
                       float* out_ptr1,
                       float* out_ptr4,
                       float* out_ptr6,
                       float* out_ptr8)
{
    // This is outerloop fused node buf0_buf1_buf3_buf5_buf6_buf4_buf7
    {
        #pragma GCC ivdep
        for(long x0=static_cast<long>(0L); x0<static_cast<long>(4L); x0+=static_cast<long>(1L))
        {
            {
                Welford<float> tmp_acc0 = Welford<float>();
                Welford<at::vec::Vectorized<float>> tmp_acc0_vec = Welford<at::vec::Vectorized<float>>();
                static WeightRecp<at::vec::Vectorized<float>> weight_recps(static_cast<long>(512L));
                for(long x1=static_cast<long>(0L); x1<static_cast<long>(8192L); x1+=static_cast<long>(16L))
                {
                    auto tmp0 = at::vec::Vectorized<float>::loadu(in_ptr0 + static_cast<long>(x1 + (8192L*x0)), 16);
                    tmp_acc0_vec = welford_combine(tmp_acc0_vec, tmp0, &weight_recps);
                }
                tmp_acc0 = welford_combine(tmp_acc0, welford_vec_reduce_all(tmp_acc0_vec));
                out_ptr0[static_cast<long>(x0)] = static_cast<float>(tmp_acc0.mean);
                out_ptr1[static_cast<long>(x0)] = static_cast<float>(tmp_acc0.m2);
            }
        }
    }
    {
        #pragma omp simd simdlen(8) 
        for(long x0=static_cast<long>(0L); x0<static_cast<long>(4L); x0+=static_cast<long>(1L))
        {
            auto tmp0 = out_ptr1[static_cast<long>(x0)];
            auto tmp1 = static_cast<float>(8192.0);
            auto tmp2 = tmp0 / tmp1;
            auto tmp3 = static_cast<float>(1e-05);
            auto tmp4 = decltype(tmp2)(tmp2 + tmp3);
            auto tmp5 = 1 / std::sqrt(tmp4);
            out_ptr4[static_cast<long>(x0)] = tmp5;
        }
    }
    {
        #pragma GCC ivdep
        for(long x0=static_cast<long>(0L); x0<static_cast<long>(4L); x0+=static_cast<long>(1L))
        {
            for(long x1=static_cast<long>(0L); x1<static_cast<long>(8192L); x1+=static_cast<long>(16L))
            {
                auto tmp0 = at::vec::Vectorized<float>::loadu(in_ptr0 + static_cast<long>(x1 + (8192L*x0)), 16);
                auto tmp1 = out_ptr0[static_cast<long>(x0)];
                auto tmp4 = out_ptr4[static_cast<long>(x0)];
                auto tmp2 = at::vec::Vectorized<float>(tmp1);
                auto tmp3 = tmp0 - tmp2;
                auto tmp5 = at::vec::Vectorized<float>(tmp4);
                auto tmp6 = tmp3 * tmp5;
                tmp6.store(out_ptr6 + static_cast<long>(x1 + (8192L*x0)));
            }
        }
    }
    // This is Fused node buf8_buf9
    {
        #pragma GCC ivdep
        for(long x0=static_cast<long>(0L); x0<static_cast<long>(2L); x0+=static_cast<long>(1L))
        {
            #pragma GCC ivdep
            for(long x1=static_cast<long>(0L); x1<static_cast<long>(16L); x1+=static_cast<long>(1L))
            {
                for(long x2=static_cast<long>(0L); x2<static_cast<long>(1024L); x2+=static_cast<long>(16L))
                {
                    auto tmp0 = at::vec::Vectorized<float>::loadu(out_ptr6 + static_cast<long>(x2 + (1024L*x1) + (16384L*x0)), 16);
                    auto tmp1 = in_ptr1[static_cast<long>(x1)];
                    auto tmp4 = in_ptr2[static_cast<long>(x1)];
                    auto tmp2 = at::vec::Vectorized<float>(tmp1);
                    auto tmp3 = tmp0 * tmp2;
                    auto tmp5 = at::vec::Vectorized<float>(tmp4);
                    auto tmp6 = tmp3 + tmp5;
                    tmp6.store(out_ptr8 + static_cast<long>(x2 + (1024L*x1) + (16384L*x0)));
                }
            }
        }
    }
}
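
For reference, this is roughly the shape I expected after outer loop fusion: the three x0 loops of the outer-loop-fused node collapsed into a single outer loop. Bodies are abbreviated to comments; this is only a sketch of the expected structure, not actual generated code.

for(long x0=static_cast<long>(0L); x0<static_cast<long>(4L); x0+=static_cast<long>(1L))
{
    // 1) Welford reduction of in_ptr0 over x1   -> out_ptr0[x0] (mean), out_ptr1[x0] (m2)
    // 2) rstd = 1 / sqrt(m2 / 8192 + 1e-05)     -> out_ptr4[x0]
    // 3) (in_ptr0 - mean) * rstd over x1        -> out_ptr6[x1 + 8192*x0]
}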

Hi @vivekvpandya, thanks for starting the discussion!

Before diving into my understanding of the issue, I’d like to briefly explain how we generate code for OuterLoopFusedSchedulerNode in the Inductor C++ backend:

We first generate code for each individual SchedulerNode within the OuterLoopFusedSchedulerNode. After that, we check whether any loop levels from the generated kernels can be merged at the outer loop level. Specifically, we examine whether two loop levels from different kernels share the same loop attributes—such as whether the loop is vectorized and the vectorization step size. If the attributes match, we merge the loop levels.

Looking at the generated code you posted, the outer-loop-fused node contains three separate x0 loops. The middle one:

#pragma omp simd simdlen(8) 
for (long x0 = static_cast<long>(0L); x0 < static_cast<long>(4L); x0 += static_cast<long>(1L))

appears to have different loop attributes from the loops before and after it: it is emitted as a scalar loop with a simd pragma, while the other two are plain outer loops wrapping an explicitly vectorized inner loop over x1. That mismatch might be the reason why outer loop fusion didn't take effect.
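
To make the attribute check concrete, here is a small illustrative sketch. The real logic lives in torch/_inductor/codegen/cpp.py and is written in Python; the struct and field names below are made up, and this is only a toy model of the idea, not the actual implementation.

#include <cassert>

// Hypothetical stand-in for the attributes of one loop level.
struct LoopAttrs {
    long size;   // trip count
    long step;   // step size (1 for a scalar loop, the vector width when vectorized)
    bool simd;   // whether the level itself carries a simd/vectorized annotation
};

// Merge two outer loop levels only if every compared attribute matches.
bool can_merge(const LoopAttrs& a, const LoopAttrs& b) {
    return a.size == b.size && a.step == b.step && a.simd == b.simd;
}

int main() {
    // First and third x0 loops: plain outer loops wrapping a vectorized x1 loop.
    LoopAttrs plain_outer{4, 1, /*simd=*/false};
    // Middle x0 loop: emitted directly with "#pragma omp simd simdlen(8)".
    LoopAttrs simd_outer{4, 1, /*simd=*/true};
    // A mismatch in any compared attribute rejects the merge, so the three
    // x0 loops are left separate instead of being fused into one outer loop.
    assert(!can_merge(plain_outer, simd_outer));
    return 0;
}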

If needed, I’m happy to take a closer look—please share the PyTorch version, CPU info, and the command you used to run it.

Thanks @leslie-fang-intel for the quick reply. I am on PyTorch 2.4, and the CPU is an AMD EPYC 9124 16-Core Processor.

I debugged and found that the node is marked as an OuterLoopFusedSchedulerNode, but it fails to merge the outer loops and falls back to normal codegen at pytorch/torch/_inductor/codegen/cpp.py at ee1b6804381c57161c477caa380a840a84167676 · pytorch/pytorch · GitHub

I also checked with torch 2.6; it also fails to merge the outer loops.