torch.compile: generated Triton kernel seems wrong

My original code:

import torch


def square(x):
    x = torch.square(x)
    if x.sum() > 0:
        x = x-1
    return torch.square(x)


opt_square = torch.compile(square)
opt_square(torch.randn(10000, 10000).cuda())

Generated Triton kernel:

@triton.jit
def triton_(in_ptr0, out_ptr0, xnumel, XBLOCK : tl.constexpr):
    xnumel = 100000000
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:]
    xmask = xindex < xnumel
    x0 = xindex
    tmp0 = tl.load(in_ptr0 + (x0), xmask)
    tmp1 = 1.0
    tmp2 = tmp0 - tmp1
    tmp3 = tmp2 * tmp2
    tl.store(out_ptr0 + (x0), tmp3, xmask)

That is, for the given input:

torch:  (x**2 - 1)**2
triton: (x - 1)**2
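
For reference, a minimal sketch of how the generated code can be dumped, assuming a recent PyTorch where torch._logging.set_logs is available (equivalent to running the script with TORCH_LOGS="output_code"):

import torch

# Ask Inductor to print every file of generated code it produces,
# including all Triton kernels, not just the last one.
torch._logging.set_logs(output_code=True)


def square(x):
    x = torch.square(x)
    if x.sum() > 0:
        x = x - 1
    return torch.square(x)


torch.compile(square)(torch.randn(10000, 10000).cuda())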

I cannot reproduce the issue on a 3090 with a current nightly binary:

import torch


def square(x):
    x = torch.square(x)
    if x.sum() > 0:
        x = x-1
    return torch.square(x)


opt_square = torch.compile(square)
x = torch.randn(10000, 10000).cuda()
out = opt_square(x)
ref = square(x)

print((out - ref).abs().max())
# tensor(0., device='cuda:0')

I don't see any issue either, running your code on my 4060. But I generated and checked the Triton code again, and it is the same as above. It could be something silly on my end.

Found it: it's generating multiple files and kernels, and I was only looking at one of them.
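
That would explain it: the data-dependent `if x.sum() > 0` forces a graph break, so torch.compile compiles the function as two graphs — one computing x**2 plus the sum reduction, and one computing (x - 1)**2 on the already-squared input. The pasted kernel is presumably just the second piece, and composing the two gives the correct (x**2 - 1)**2. A minimal sketch to confirm the split, assuming torch._dynamo.explain is available:

import torch
import torch._dynamo


def square(x):
    x = torch.square(x)
    if x.sum() > 0:
        x = x - 1
    return torch.square(x)


# explain() traces the function the way torch.compile would and reports
# how it was split; the data-dependent `if` should register as a graph
# break, so more than one graph (and kernel file) gets generated.
explanation = torch._dynamo.explain(square)(torch.randn(10000, 10000).cuda())
print(explanation.graph_count)        # expected: 2
print(explanation.graph_break_count)  # expected: 1
print(explanation.break_reasons)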