XBLOCK is not defined

Hi,

I hit this error when I tried to compile my model (which is composed of only fully-connected layers, sigmoid and ReLU activations, batch norm, and reshape). Does anyone know what’s causing this? It only happens with the max-autotune compile mode; both the default and reduce-overhead compile modes work.

Traceback (most recent call last):
File “/miniconda/lib/python3.9/site-packages/triton/compiler.py”, line 937, in build_triton_ir
generator.visit(fn.parse())
File “/miniconda/lib/python3.9/site-packages/triton/compiler.py”, line 855, in visit
return super().visit(node)
File “/miniconda/lib/python3.9/ast.py”, line 407, in visit
return visitor(node)
File “/miniconda/lib/python3.9/site-packages/triton/compiler.py”, line 183, in visit_Module
ast.NodeVisitor.generic_visit(self, node)
File “/miniconda/lib/python3.9/ast.py”, line 415, in generic_visit
self.visit(item)
File “/miniconda/lib/python3.9/site-packages/triton/compiler.py”, line 855, in visit
return super().visit(node)
File “/miniconda/lib/python3.9/ast.py”, line 407, in visit
return visitor(node)
File “/miniconda/lib/python3.9/site-packages/triton/compiler.py”, line 252, in visit_FunctionDef
has_ret = self.visit_compound_statement(node.body)
File “/miniconda/lib/python3.9/site-packages/triton/compiler.py”, line 177, in visit_compound_statement
self.last_ret_type = self.visit(stmt)
File “/miniconda/lib/python3.9/site-packages/triton/compiler.py”, line 855, in visit
return super().visit(node)
File “/miniconda/lib/python3.9/ast.py”, line 407, in visit
return visitor(node)
File “/miniconda/lib/python3.9/site-packages/triton/compiler.py”, line 301, in visit_Assign
values = self.visit(node.value)
File “/miniconda/lib/python3.9/site-packages/triton/compiler.py”, line 855, in visit
return super().visit(node)
File “/miniconda/lib/python3.9/ast.py”, line 407, in visit
return visitor(node)
File “/miniconda/lib/python3.9/site-packages/triton/compiler.py”, line 757, in visit_Call
args = [self.visit(arg) for arg in node.args]
File “/miniconda/lib/python3.9/site-packages/triton/compiler.py”, line 757, in &lt;listcomp&gt;
args = [self.visit(arg) for arg in node.args]
File “/miniconda/lib/python3.9/site-packages/triton/compiler.py”, line 855, in visit
return super().visit(node)
File “/miniconda/lib/python3.9/ast.py”, line 407, in visit
return visitor(node)
File “/miniconda/lib/python3.9/site-packages/triton/compiler.py”, line 188, in visit_List
elts = [self.visit(elt) for elt in node.elts]
File “/miniconda/lib/python3.9/site-packages/triton/compiler.py”, line 188, in &lt;listcomp&gt;
elts = [self.visit(elt) for elt in node.elts]
File “/miniconda/lib/python3.9/site-packages/triton/compiler.py”, line 855, in visit
return super().visit(node)
File “/miniconda/lib/python3.9/ast.py”, line 407, in visit
return visitor(node)
File “/miniconda/lib/python3.9/site-packages/triton/compiler.py”, line 325, in visit_Name
return self.get_value(node.id)
File “/miniconda/lib/python3.9/site-packages/triton/compiler.py”, line 156, in get_value
raise ValueError(f’{name} is not defined’)
ValueError: XBLOCK is not defined

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
File “/miniconda/lib/python3.9/concurrent/futures/process.py”, line 246, in _process_worker
r = call_item.fn(*call_item.args, **call_item.kwargs)
File “/miniconda/lib/python3.9/site-packages/torch/_inductor/codecache.py”, line 549, in _worker_compile
kernel.precompile(warm_cache_only_with_cc=cc)
File “/miniconda/lib/python3.9/site-packages/torch/_inductor/triton_ops/autotune.py”, line 69, in precompile
self.launchers = [
File “/miniconda/lib/python3.9/site-packages/torch/_inductor/triton_ops/autotune.py”, line 70, in &lt;listcomp&gt;
self._precompile_config(c, warm_cache_only_with_cc)
File “/miniconda/lib/python3.9/site-packages/torch/_inductor/triton_ops/autotune.py”, line 83, in precompile_config
triton.compile(
File “/miniconda/lib/python3.9/site-packages/triton/compiler.py”, line 1620, in compile
next_module = compile(module)
File “/miniconda/lib/python3.9/site-packages/triton/compiler.py”, line 1549, in &lt;lambda&gt;
lambda src: ast_to_ttir(src, signature, configs[0], constants)),
File “/miniconda/lib/python3.9/site-packages/triton/compiler.py”, line 962, in ast_to_ttir
mod, _ = build_triton_ir(fn, signature, specialization, constants)
File “/miniconda/lib/python3.9/site-packages/triton/compiler.py”, line 942, in build_triton_ir
raise CompilationError(fn.src, node) from e
triton.compiler.CompilationError: at 64:39:
def triton_(arg_A, arg_B, in_ptr2, in_ptr3, out_ptr2):
GROUP_M : tl.constexpr = 8
EVEN_K : tl.constexpr = False
ALLOW_TF32 : tl.constexpr = False
ACC_TYPE : tl.constexpr = tl.float32
BLOCK_M : tl.constexpr = 32
BLOCK_N : tl.constexpr = 32
BLOCK_K : tl.constexpr = 16

A = arg_A
B = arg_B

M = 1000
N = 512
K = 151
stride_am = 151
stride_ak = 1
stride_bk = 512
stride_bn = 1

# based on triton.ops.matmul
pid = tl.program_id(0)
grid_m = (M + BLOCK_M - 1) // BLOCK_M
grid_n = (N + BLOCK_N - 1) // BLOCK_N

# re-order program ID for better L2 performance
width = GROUP_M * grid_n
group_id = pid // width
group_size = min(grid_m - group_id * GROUP_M, GROUP_M)
pid_m = group_id * GROUP_M + (pid % group_size)
pid_n = (pid % width) // (group_size)

rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
ram = tl.max_contiguous(tl.multiple_of(rm % M, BLOCK_M), BLOCK_M)
rbn = tl.max_contiguous(tl.multiple_of(rn % N, BLOCK_N), BLOCK_N)
rk = tl.arange(0, BLOCK_K)
A = A + (ram[:, None] * stride_am + rk[None, :] * stride_ak)
B = B + (rk[:, None] * stride_bk + rbn[None, :] * stride_bn)

acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=ACC_TYPE)
for k in range(K, 0, -BLOCK_K):
    if EVEN_K:
        a = tl.load(A)
        b = tl.load(B)
    else:
        a = tl.load(A, mask=rk[None, :] < k, other=0.)
        b = tl.load(B, mask=rk[:, None] < k, other=0.)
    acc += tl.dot(a, b, allow_tf32=ALLOW_TF32)
    A += BLOCK_K * stride_ak
    B += BLOCK_K * stride_bk

# rematerialize rm and rn to save registers
rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
idx_m = rm[:, None]
idx_n = rn[None, :]
mask = (idx_m < M) & (idx_n < N)

# inductor generates a suffix
xindex = idx_n + (512*idx_m)
tmp0 = tl.load(in_ptr2 + (xindex + tl.zeros(mask.shape, tl.int32)), mask)
tmp2_load = tl.load(in_ptr3 + (0))
tmp2 = tl.broadcast_to(tmp2_load, [XBLOCK])

Could you post a minimal and executable code snippet reproducing the issue so that we could debug it, please?