My original code:
import torch
def square(x):
    """Square ``x``, conditionally subtract 1, then square again.

    The input is squared element-wise. If the sum of the squared values is
    greater than zero, 1 is subtracted from every element. The result is
    then squared element-wise once more and returned.

    Args:
        x: A tensor accepted by ``torch.square``.

    Returns:
        ``(x**2 - 1)**2`` when ``(x**2).sum() > 0``, otherwise ``(x**2)**2``.
    """
    x = torch.square(x)
    # The branch condition depends on tensor data, not just on shapes.
    # NOTE(review): under torch.compile this kind of data-dependent branch is
    # expected to split the function into separately compiled graphs -- confirm
    # against the graph-break documentation.
    if x.sum() > 0:
        x = x - 1
    return torch.square(x)
# Wrap `square` with torch.compile, then call the wrapped function once on a
# random 10000 x 10000 tensor that has been moved to the GPU. The returned
# value is not kept.
opt_square = torch.compile(square)
gpu_input = torch.randn(10000, 10000).cuda()
opt_square(gpu_input)
Generated Triton kernel:
@triton.jit
def triton_(in_ptr0, out_ptr0, xnumel, XBLOCK : tl.constexpr):
    # Element-wise kernel: for each in-range index i it stores
    # (in_ptr0[i] - 1.0) * (in_ptr0[i] - 1.0) into out_ptr0[i].
    #
    # NOTE(review): in_ptr0 is presumably the already-squared tensor produced
    # by a separately compiled graph (split at the data-dependent `if`), not
    # the original input -- confirm by inspecting all generated kernels.

    # The xnumel argument is overwritten with a fixed element count
    # (100000000 = 10000 * 10000).
    xnumel = 100000000
    # Each program instance handles one contiguous run of XBLOCK indices.
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:]
    # Mask off indices past the end of the data.
    xmask = xindex < xnumel
    x0 = xindex
    tmp0 = tl.load(in_ptr0 + (x0), xmask)
    tmp1 = 1.0
    tmp2 = tmp0 - tmp1
    tmp3 = tmp2 * tmp2
    tl.store(out_ptr0 + (x0), tmp3, xmask)
''', device_str='cuda')
That is, for a given input x:
torch (original function) computes: (x**2 - 1)**2
the Triton kernel shown computes: (x - 1)**2