I have a fairly straightforward Transformer implementation that works fine with torch.compile compiling the .forward() pass when the input_ids size is 1. However, when I invoke it with my custom Triton attention kernel, inductor's CUDA graph trees re-records the function on every call, even after warmup and even with identical input tensor shapes.
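For context, the calling pattern is essentially a greedy decode loop over the compiled forward. The sketch below is a heavily simplified stand-in (placeholder model and names; in the real model the attention dispatches to the Triton kernel, and the cudagraphs come from inductor's reduce-overhead path):

import torch
import torch.nn as nn

# Placeholder standing in for the real decoder-only Transformer, just so the
# driver loop below is runnable; the real model's attention calls my Triton kernel.
class TinyModel(nn.Module):
    def __init__(self, vocab=100, dim=64):
        super().__init__()
        self.emb = nn.Embedding(vocab, dim)
        self.proj = nn.Linear(dim, vocab)

    def forward(self, input_ids):
        return self.proj(self.emb(input_ids))

model = TinyModel().cuda().eval()

# mode="reduce-overhead" is what enables CUDA graph trees under inductor
# (hence the cudagraph_trees.py lines in the logs further down).
compiled_forward = torch.compile(model.forward, mode="reduce-overhead")

input_ids = torch.zeros(1, 1, dtype=torch.long, device="cuda")
with torch.inference_mode():
    for _ in range(12):
        logits = compiled_forward(input_ids)                  # identical shape every call
        input_ids = logits[:, -1].argmax(-1, keepdim=True)    # feed back a single token id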
I’ve been trying to figure out the correct logging voodoo to get it to tell me what the cause of the re-recording is, but the best I can manage is to confirm that it is indeed re-recording on every call (I have also verified this in the profiler).
I’m using this:
import torch
import torch._logging
import logging
torch._logging.set_logs(
dynamo=logging.INFO, inductor=logging.INFO,
recompiles=True, recompiles_verbose=True, cudagraphs=True, graph_breaks=True)
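For completeness, as far as I can tell the same artifacts can also be enabled with the TORCH_LOGS environment variable instead of set_logs() (normally exported in the shell before launching; if done from Python it has to happen before torch is imported):

import os
# Same artifacts as the set_logs() call above, via the TORCH_LOGS env var
# ("+dynamo" in the list would additionally bump dynamo to DEBUG).
os.environ["TORCH_LOGS"] = "recompiles,recompiles_verbose,graph_breaks,cudagraphs"

import torch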
and I get this output (one line per invocation of my .forward()):
V0304 23:53:37.255000 15408 torch/_inductor/cudagraph_trees.py:2157] [__cudagraphs] Recording function 0 of graph recording id 10
V0304 23:53:37.468000 15408 torch/_inductor/cudagraph_trees.py:2157] [__cudagraphs] Recording function 0 of graph recording id 11
V0304 23:53:37.665000 15408 torch/_inductor/cudagraph_trees.py:2157] [__cudagraphs] Recording function 0 of graph recording id 12
V0304 23:53:37.846000 15408 torch/_inductor/cudagraph_trees.py:2157] [__cudagraphs] Recording function 0 of graph recording id 13
V0304 23:53:38.077000 15408 torch/_inductor/cudagraph_trees.py:2157] [__cudagraphs] Recording function 0 of graph recording id 14
V0304 23:53:38.334000 15408 torch/_inductor/cudagraph_trees.py:2157] [__cudagraphs] Recording function 0 of graph recording id 15
V0304 23:53:38.589000 15408 torch/_inductor/cudagraph_trees.py:2157] [__cudagraphs] Recording function 0 of graph recording id 16
V0304 23:53:38.800000 15408 torch/_inductor/cudagraph_trees.py:2157] [__cudagraphs] Recording function 0 of graph recording id 17
V0304 23:53:39.041000 15408 torch/_inductor/cudagraph_trees.py:2157] [__cudagraphs] Recording function 0 of graph recording id 18
V0304 23:53:39.264000 15408 torch/_inductor/cudagraph_trees.py:2157] [__cudagraphs] Recording function 0 of graph recording id 19
V0304 23:53:39.482000 15408 torch/_inductor/cudagraph_trees.py:2157] [__cudagraphs] Recording function 0 of graph recording id 20

(My generation loop prints the decoded tokens to stdout, so the text "The tallest mountain in the world is Mount Everest," was interleaved with these lines in the original output, one token per forward call.)
Any help on how I can get PyTorch to tell me why it's re-recording on every call would be great, as would any insight into what, besides tensor shapes, is likely to be causing it.
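In case it helps narrow things down, this is the kind of per-call check I'm thinking of wrapping around the compiled forward, to rule out anything about the input (other than its shape) silently changing between calls — just a rough sketch, the helper names are made up:

_last_meta = None

def input_meta(t):
    # Snapshot of everything about the input tensor that's cheap to compare per call.
    return (tuple(t.shape), t.dtype, t.device, tuple(t.stride()), t.is_contiguous())

def checked_forward(compiled_forward, input_ids):
    # Hypothetical wrapper: runs the compiled forward and reports whenever the
    # input metadata differs from the previous invocation.
    global _last_meta
    meta = input_meta(input_ids)
    if _last_meta is not None and meta != _last_meta:
        print("input metadata changed:", _last_meta, "->", meta)
    _last_meta = meta
    return compiled_forward(input_ids)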