I have a custom Triton kernel wrapped in an autograd function. When this function is run under torch.compile, the outputs are incorrect and do not match those of the non-compiled version; when it is run without compiling, the outputs are correct.
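Roughly how I observe the mismatch (a simplified sketch, not my actual benchmark: the shapes, dtype, and tolerances here are made up for illustration, and FusedChunkLinearAttentionFunction is the autograd function shown below; the head dimension is 128 so that NK = cdiv(128, 64) = 2 and the o.sum(0) path is taken):

import torch

def attn(q, k, v, scale):
    # thin wrapper so torch.compile traces through the autograd function
    o, _ = FusedChunkLinearAttentionFunction.apply(q, k, v, scale, None, False)
    return o

q = torch.randn(2, 4, 1024, 128, device="cuda", dtype=torch.float16)
k = torch.randn_like(q)
v = torch.randn_like(q)
scale = q.shape[-1] ** -0.5

ref = attn(q, k, v, scale)                 # eager: correct
out = torch.compile(attn)(q, k, v, scale)  # compiled: values differ
print(torch.allclose(ref, out, atol=1e-2, rtol=1e-2))  # False without the graph break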
I have localized the issue. Here is the Autograd function definition:
import torch
import triton
import triton.language as tl


class FusedChunkLinearAttentionFunction(torch.autograd.Function):
    @staticmethod
    def forward(ctx, q, k, v, scale, initial_state, output_final_state):
        batch_size, n_heads, seq_len, d_head_qk = q.shape
        d_head_v = v.shape[-1]
        ctx.scale = scale
        # Convert PyTorch dtype to Triton dtype
        if q.dtype == torch.float16:
            triton_dtype = tl.float16
        elif q.dtype == torch.bfloat16:
            triton_dtype = tl.bfloat16
        else:
            triton_dtype = tl.float32
        # Initialize the output tensors: one partial output per K-block,
        # reduced over dim 0 at the end
        BK = min(triton.next_power_of_2(d_head_qk), 64)
        NK = triton.cdiv(d_head_qk, BK)
        o = q.new_empty(NK, batch_size, n_heads, seq_len, d_head_v)
        if output_final_state:
            final_state = q.new_empty(
                batch_size, n_heads, d_head_qk, d_head_v, dtype=torch.float32, requires_grad=False
            )
        else:
            final_state = None
        grid = lambda args: (  # noqa: E731
            triton.cdiv(v.shape[-1], args["BV"]),
            NK,
            batch_size * n_heads,
        )
        fused_chunk_linear_attn_fwd_kernel[grid](
            q,
            k,
            v,
            o,
            initial_state,
            final_state,
            scale,
            q.stride(1),
            q.stride(2),
            q.stride(3),
            v.stride(1),
            v.stride(2),
            v.stride(3),
            batch_size,
            n_heads,
            seq_len,
            d_head_qk,
            d_head_v,
            BK=BK,
            USE_INITIAL_STATE=initial_state is not None,
            STORE_FINAL_STATE=output_final_state,
            DTYPE=triton_dtype,
        )
        ctx.save_for_backward(q, k, v, initial_state)
        ctx.triton_dtype = triton_dtype
        # This code returns incorrect outputs when running with torch.compile:
        # o = o.sum(0)
        # return o.to(q.dtype), final_state
        # If I break the compilation graph with print(), the outputs are correct:
        if NK > 1:
            print("")  # this is needed to break the torch.compile graph
            o = o.sum(0)
            return o.to(q.dtype), final_state
        else:
            # in our setting we can avoid this since NK will be one
            return o[0].to(q.dtype), final_state
The issue stems from the final o.sum(0), since it sums values that are computed by independent threadblocks on the GPU. torch.compile tries to compile the full function into a single graph, which yields incorrect values; my speculation is that the sum is performed before all the partial results are actually available.
If I instead force a graph break before the sum, the outputs are correct and match the eager-mode values.
Two questions:
- Is it expected behaviour that the outputs are incorrect in this case? I would expect torch.compile to wait until the Triton kernel has fully finished and then compute the sum correctly.
- If yes, is there a more principled way to break the graph than calling print()? (See the sketch below for the kind of replacement I have in mind.)
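For concreteness, this is the tail of forward() rewritten the way I am imagining it (untested; I am assuming torch._dynamo.graph_break(), imported via torch._dynamo at module level, forces a graph break at this point under torch.compile and is a no-op in eager mode):

if NK > 1:
    # explicit graph break instead of print(""): in eager mode this call does
    # nothing; under torch.compile it should end the current graph here, so the
    # kernel's partial outputs would be materialized before the reduction
    torch._dynamo.graph_break()
    o = o.sum(0)
    return o.to(q.dtype), final_state
else:
    # in our setting we can avoid this since NK will be one
    return o[0].to(q.dtype), final_state

I have also considered decorating forward with torch.compiler.disable, but as far as I understand that would exclude the whole function from compilation rather than just splitting the graph at this one point.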