I have the following simple test case, for which I see an OuterLoopFusedSchedulerNode being generated:
def test_group_norm():
from torch._inductor import config
# Force multple scheduler nodes creation to fuse them
config.realize_opcount_threshold = 1
config.inplace_buffers = False
class M(torch.nn.Module):
def __init__(self) -> None:
super().__init__()
self.group_norm = torch.nn.GroupNorm(2, 16)
def forward(self, x):
return self.group_norm(x)
mod = M().eval()
x = torch.randn(2, 16, 32, 32)
with torch.no_grad():
expected = mod(x)
compiled_m = torch.compile(mod, backend="inductor", fullgraph=True)
actual = compiled_m(x)
assert torch.allclose(
expected.cpu(), actual.cpu(), rtol=0.000001, atol=0.001
), "Test failed: Output does not match expected output"
However, in the generated code we can see that loops like the following (this exact loop header appears three times, with identical bounds) should have been fused into a single outer loop:
for(long x0=static_cast<long>(0L); x0<static_cast<long>(4L); x0+=static_cast<long>(1L))
Why does this not happen? Is this intentional, or is it a bug?
extern "C" void kernel(const float* in_ptr0,
const float* in_ptr1,
const float* in_ptr2,
float* out_ptr0,
float* out_ptr1,
float* out_ptr4,
float* out_ptr6,
float* out_ptr8)
{
// This is outerloop fused node buf0_buf1_buf3_buf5_buf6_buf4_buf7
{
#pragma GCC ivdep
for(long x0=static_cast<long>(0L); x0<static_cast<long>(4L); x0+=static_cast<long>(1L))
{
{
Welford<float> tmp_acc0 = Welford<float>();
Welford<at::vec::Vectorized<float>> tmp_acc0_vec = Welford<at::vec::Vectorized<float>>();
static WeightRecp<at::vec::Vectorized<float>> weight_recps(static_cast<long>(512L));
for(long x1=static_cast<long>(0L); x1<static_cast<long>(8192L); x1+=static_cast<long>(16L))
{
auto tmp0 = at::vec::Vectorized<float>::loadu(in_ptr0 + static_cast<long>(x1 + (8192L*x0)), 16);
tmp_acc0_vec = welford_combine(tmp_acc0_vec, tmp0, &weight_recps);
}
tmp_acc0 = welford_combine(tmp_acc0, welford_vec_reduce_all(tmp_acc0_vec));
out_ptr0[static_cast<long>(x0)] = static_cast<float>(tmp_acc0.mean);
out_ptr1[static_cast<long>(x0)] = static_cast<float>(tmp_acc0.m2);
}
}
}
{
#pragma omp simd simdlen(8)
for(long x0=static_cast<long>(0L); x0<static_cast<long>(4L); x0+=static_cast<long>(1L))
{
auto tmp0 = out_ptr1[static_cast<long>(x0)];
auto tmp1 = static_cast<float>(8192.0);
auto tmp2 = tmp0 / tmp1;
auto tmp3 = static_cast<float>(1e-05);
auto tmp4 = decltype(tmp2)(tmp2 + tmp3);
auto tmp5 = 1 / std::sqrt(tmp4);
out_ptr4[static_cast<long>(x0)] = tmp5;
}
}
{
#pragma GCC ivdep
for(long x0=static_cast<long>(0L); x0<static_cast<long>(4L); x0+=static_cast<long>(1L))
{
for(long x1=static_cast<long>(0L); x1<static_cast<long>(8192L); x1+=static_cast<long>(16L))
{
auto tmp0 = at::vec::Vectorized<float>::loadu(in_ptr0 + static_cast<long>(x1 + (8192L*x0)), 16);
auto tmp1 = out_ptr0[static_cast<long>(x0)];
auto tmp4 = out_ptr4[static_cast<long>(x0)];
auto tmp2 = at::vec::Vectorized<float>(tmp1);
auto tmp3 = tmp0 - tmp2;
auto tmp5 = at::vec::Vectorized<float>(tmp4);
auto tmp6 = tmp3 * tmp5;
tmp6.store(out_ptr6 + static_cast<long>(x1 + (8192L*x0)));
}
}
}
// This is Fused node buf8_buf9
{
#pragma GCC ivdep
for(long x0=static_cast<long>(0L); x0<static_cast<long>(2L); x0+=static_cast<long>(1L))
{
#pragma GCC ivdep
for(long x1=static_cast<long>(0L); x1<static_cast<long>(16L); x1+=static_cast<long>(1L))
{
for(long x2=static_cast<long>(0L); x2<static_cast<long>(1024L); x2+=static_cast<long>(16L))
{
auto tmp0 = at::vec::Vectorized<float>::loadu(out_ptr6 + static_cast<long>(x2 + (1024L*x1) + (16384L*x0)), 16);
auto tmp1 = in_ptr1[static_cast<long>(x1)];
auto tmp4 = in_ptr2[static_cast<long>(x1)];
auto tmp2 = at::vec::Vectorized<float>(tmp1);
auto tmp3 = tmp0 * tmp2;
auto tmp5 = at::vec::Vectorized<float>(tmp4);
auto tmp6 = tmp3 + tmp5;
tmp6.store(out_ptr8 + static_cast<long>(x2 + (1024L*x1) + (16384L*x0)));
}
}
}
}
}