Introduction
I recently tried torch.compile with mode="max-autotune" and have some questions about when autotuning is triggered for the mm operator.
My test program is as follows:
import torch
import torch.nn.functional as F
from torch import nn

class TestMLP(nn.Module):
    def __init__(self, intermediate_size=3420, hidden_size=1280):
        super().__init__()
        self.proj = nn.Linear(intermediate_size, hidden_size, bias=False,
                              dtype=torch.bfloat16, device='cuda')

    def forward(self, hidden_states):
        return self.proj(hidden_states)

if __name__ == '__main__':
    mlp = TestMLP()
    mlp = torch.compile(mlp, mode="max-autotune")
    seq_lens = [560, 1024, 2048]
    for seq_len in seq_lens:
        hidden_states = torch.randn(seq_len, 3420, device='cuda', dtype=torch.bfloat16)
        torch._dynamo.mark_dynamic(hidden_states, 0)
        output = mlp(hidden_states)
        print(f"{output.shape=}", flush=True)
The output of the test program is as follows:
AUTOTUNE mm(560x3420, 3420x1280)
triton_mm_11 0.0498 ms 100.0% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=64, BLOCK_M=64, BLOCK_N=128, B_PROLOGUE_CAST_TYPE=None, EVEN_K=False, GROUP_M=8, num_stages=3, num_warps=4
triton_mm_10 0.0583 ms 85.4% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=32, BLOCK_M=64, BLOCK_N=128, B_PROLOGUE_CAST_TYPE=None, EVEN_K=False, GROUP_M=8, num_stages=4, num_warps=8
triton_mm_14 0.0593 ms 84.0% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=32, BLOCK_M=128, BLOCK_N=64, B_PROLOGUE_CAST_TYPE=None, EVEN_K=False, GROUP_M=8, num_stages=4, num_warps=8
triton_mm_7 0.0657 ms 75.8% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=64, BLOCK_M=64, BLOCK_N=64, B_PROLOGUE_CAST_TYPE=None, EVEN_K=False, GROUP_M=8, num_stages=3, num_warps=8
triton_mm_9 0.0667 ms 74.7% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=32, BLOCK_M=64, BLOCK_N=128, B_PROLOGUE_CAST_TYPE=None, EVEN_K=False, GROUP_M=8, num_stages=3, num_warps=4
triton_mm_13 0.0669 ms 74.4% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=32, BLOCK_M=128, BLOCK_N=64, B_PROLOGUE_CAST_TYPE=None, EVEN_K=False, GROUP_M=8, num_stages=3, num_warps=4
triton_mm_16 0.0758 ms 65.7% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=32, BLOCK_M=128, BLOCK_N=128, B_PROLOGUE_CAST_TYPE=None, EVEN_K=False, GROUP_M=8, num_stages=3, num_warps=4
mm 0.0792 ms 62.9%
triton_mm_1 0.0878 ms 56.7% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=128, BLOCK_M=32, BLOCK_N=32, B_PROLOGUE_CAST_TYPE=None, EVEN_K=False, GROUP_M=8, num_stages=2, num_warps=4
triton_mm_8 0.0878 ms 56.7% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=128, BLOCK_M=64, BLOCK_N=64, B_PROLOGUE_CAST_TYPE=None, EVEN_K=False, GROUP_M=8, num_stages=5, num_warps=4
SingleProcess AUTOTUNE benchmarking takes 2.5051 seconds and 5.1089 seconds precompiling for 20 choices
AUTOTUNE mm(1280x560, 560x3420)
triton_mm_35 0.0383 ms 100.0% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=32, BLOCK_M=128, BLOCK_N=128, B_PROLOGUE_CAST_TYPE=None, EVEN_K=False, GROUP_M=8, num_stages=3, num_warps=4
triton_mm_32 0.0411 ms 93.4% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=32, BLOCK_M=128, BLOCK_N=64, B_PROLOGUE_CAST_TYPE=None, EVEN_K=False, GROUP_M=8, num_stages=3, num_warps=4
triton_mm_30 0.0448 ms 85.5% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=64, BLOCK_M=64, BLOCK_N=128, B_PROLOGUE_CAST_TYPE=None, EVEN_K=False, GROUP_M=8, num_stages=3, num_warps=4
triton_mm_28 0.0458 ms 83.8% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=32, BLOCK_M=64, BLOCK_N=128, B_PROLOGUE_CAST_TYPE=None, EVEN_K=False, GROUP_M=8, num_stages=3, num_warps=4
triton_mm_33 0.0458 ms 83.8% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=32, BLOCK_M=128, BLOCK_N=64, B_PROLOGUE_CAST_TYPE=None, EVEN_K=False, GROUP_M=8, num_stages=4, num_warps=8
triton_mm_29 0.0472 ms 81.3% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=32, BLOCK_M=64, BLOCK_N=128, B_PROLOGUE_CAST_TYPE=None, EVEN_K=False, GROUP_M=8, num_stages=4, num_warps=8
triton_mm_37 0.0495 ms 77.4% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=64, BLOCK_M=128, BLOCK_N=128, B_PROLOGUE_CAST_TYPE=None, EVEN_K=False, GROUP_M=8, num_stages=5, num_warps=8
triton_mm_26 0.0571 ms 67.2% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=64, BLOCK_M=64, BLOCK_N=64, B_PROLOGUE_CAST_TYPE=None, EVEN_K=False, GROUP_M=8, num_stages=3, num_warps=8
triton_mm_25 0.0619 ms 61.9% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=32, BLOCK_M=64, BLOCK_N=64, B_PROLOGUE_CAST_TYPE=None, EVEN_K=False, GROUP_M=8, num_stages=2, num_warps=4
triton_mm_36 0.0642 ms 59.7% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=64, BLOCK_M=128, BLOCK_N=128, B_PROLOGUE_CAST_TYPE=None, EVEN_K=False, GROUP_M=8, num_stages=3, num_warps=4
SingleProcess AUTOTUNE benchmarking takes 2.6340 seconds and 4.7325 seconds precompiling for 20 choices
output.shape=torch.Size([560, 1280])
output.shape=torch.Size([1024, 1280])
AUTOTUNE mm(2048x3420, 3420x1280)
triton_mm_49 0.1237 ms 100.0% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=64, BLOCK_M=64, BLOCK_N=128, B_PROLOGUE_CAST_TYPE=None, EVEN_K=False, GROUP_M=8, num_stages=3, num_warps=4
triton_mm_54 0.1260 ms 98.2% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=32, BLOCK_M=128, BLOCK_N=128, B_PROLOGUE_CAST_TYPE=None, EVEN_K=False, GROUP_M=8, num_stages=3, num_warps=4
triton_mm_51 0.1407 ms 87.9% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=32, BLOCK_M=128, BLOCK_N=64, B_PROLOGUE_CAST_TYPE=None, EVEN_K=False, GROUP_M=8, num_stages=3, num_warps=4
triton_mm_47 0.1423 ms 87.0% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=32, BLOCK_M=64, BLOCK_N=128, B_PROLOGUE_CAST_TYPE=None, EVEN_K=False, GROUP_M=8, num_stages=3, num_warps=4
triton_mm_48 0.1598 ms 77.4% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=32, BLOCK_M=64, BLOCK_N=128, B_PROLOGUE_CAST_TYPE=None, EVEN_K=False, GROUP_M=8, num_stages=4, num_warps=8
triton_mm_52 0.1600 ms 77.3% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=32, BLOCK_M=128, BLOCK_N=64, B_PROLOGUE_CAST_TYPE=None, EVEN_K=False, GROUP_M=8, num_stages=4, num_warps=8
triton_mm_45 0.1791 ms 69.1% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=64, BLOCK_M=64, BLOCK_N=64, B_PROLOGUE_CAST_TYPE=None, EVEN_K=False, GROUP_M=8, num_stages=3, num_warps=8
triton_mm_55 0.1796 ms 68.9% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=64, BLOCK_M=128, BLOCK_N=128, B_PROLOGUE_CAST_TYPE=None, EVEN_K=False, GROUP_M=8, num_stages=3, num_warps=4
triton_mm_56 0.2154 ms 57.4% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=64, BLOCK_M=128, BLOCK_N=128, B_PROLOGUE_CAST_TYPE=None, EVEN_K=False, GROUP_M=8, num_stages=5, num_warps=8
triton_mm_46 0.2214 ms 55.9% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=128, BLOCK_M=64, BLOCK_N=64, B_PROLOGUE_CAST_TYPE=None, EVEN_K=False, GROUP_M=8, num_stages=5, num_warps=4
SingleProcess AUTOTUNE benchmarking takes 2.6472 seconds and 0.0729 seconds precompiling for 20 choices
AUTOTUNE mm(1280x2048, 2048x3420)
triton_mm_70 0.1017 ms 100.0% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=32, BLOCK_M=128, BLOCK_N=64, B_PROLOGUE_CAST_TYPE=None, EVEN_K=False, GROUP_M=8, num_stages=3, num_warps=4
triton_mm_73 0.1017 ms 100.0% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=32, BLOCK_M=128, BLOCK_N=128, B_PROLOGUE_CAST_TYPE=None, EVEN_K=False, GROUP_M=8, num_stages=3, num_warps=4
triton_mm_71 0.1205 ms 84.4% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=32, BLOCK_M=128, BLOCK_N=64, B_PROLOGUE_CAST_TYPE=None, EVEN_K=False, GROUP_M=8, num_stages=4, num_warps=8
triton_mm_68 0.1223 ms 83.1% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=64, BLOCK_M=64, BLOCK_N=128, B_PROLOGUE_CAST_TYPE=None, EVEN_K=False, GROUP_M=8, num_stages=3, num_warps=4
triton_mm_66 0.1250 ms 81.3% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=32, BLOCK_M=64, BLOCK_N=128, B_PROLOGUE_CAST_TYPE=None, EVEN_K=False, GROUP_M=8, num_stages=3, num_warps=4
triton_mm_75 0.1275 ms 79.8% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=64, BLOCK_M=128, BLOCK_N=128, B_PROLOGUE_CAST_TYPE=None, EVEN_K=False, GROUP_M=8, num_stages=5, num_warps=8
triton_mm_67 0.1309 ms 77.7% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=32, BLOCK_M=64, BLOCK_N=128, B_PROLOGUE_CAST_TYPE=None, EVEN_K=False, GROUP_M=8, num_stages=4, num_warps=8
mm 0.1547 ms 65.7%
triton_mm_64 0.1625 ms 62.6% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=64, BLOCK_M=64, BLOCK_N=64, B_PROLOGUE_CAST_TYPE=None, EVEN_K=False, GROUP_M=8, num_stages=3, num_warps=8
triton_mm_69 0.1685 ms 60.3% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=128, BLOCK_M=64, BLOCK_N=128, B_PROLOGUE_CAST_TYPE=None, EVEN_K=False, GROUP_M=8, num_stages=4, num_warps=4
SingleProcess AUTOTUNE benchmarking takes 2.6176 seconds and 0.0887 seconds precompiling for 20 choices
output.shape=torch.Size([2048, 1280])
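For reference, one way to see exactly when torch.compile recompiles (a sketch, assuming stock Dynamo settings) is to surface recompilations explicitly:

import torch._dynamo

# Raise as soon as an already-compiled function is recompiled, which
# pinpoints the guard (e.g. a shape guard) that failed:
torch._dynamo.config.error_on_recompile = True

Alternatively, running the script with TORCH_LOGS=recompiles prints the reason for each recompilation without raising.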
It seems that seq_len = 560 and 1024 share the same Triton mm kernels, but seq_len = 2048 triggers re-autotuning of the mm operator. I guess the regeneration may be related to the tiling rule. I would like to know which code in PyTorch determines whether a new Triton mm kernel needs to be autotuned for a new shape. Once that rule is known, we could warm up the relevant shapes in advance and avoid autotuning during online serving, where it hurts the user experience.
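In the meantime, here is a minimal warmup sketch of what I have in mind. It assumes that autotune results are reused as long as a new shape does not trigger a fresh compilation, and that Inductor's FX graph cache (TORCHINDUCTOR_FX_GRAPH_CACHE=1, or torch._inductor.config.fx_graph_cache = True) persists compiled artifacts across processes; warmup_seq_lens is a hypothetical list that would have to cover whatever shape buckets cause re-autotuning:

import os
os.environ.setdefault("TORCHINDUCTOR_FX_GRAPH_CACHE", "1")  # persist compiled graphs on disk; set before importing torch

import torch

def warmup(model, warmup_seq_lens, hidden_size=3420):
    # Run the compiled model once per representative sequence length so that
    # compilation and autotuning happen before any user-facing traffic.
    # warmup_seq_lens is a hypothetical list of representative shapes.
    for seq_len in warmup_seq_lens:
        x = torch.randn(seq_len, hidden_size, device='cuda', dtype=torch.bfloat16)
        torch._dynamo.mark_dynamic(x, 0)
        model(x)
    torch.cuda.synchronize()

# e.g. warmup(mlp, [560, 2048]) before serving starts

The open question remains how to choose warmup_seq_lens: without knowing the rule that decides when a new shape re-autotunes, I cannot tell which representative lengths are sufficient.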
Environment
torch version: 2.6.0+cu124
platform: A100