Why is inductor re-recording my function with every call?

I have a fairly straightforward Transformer implementation that works fine with torch.compile compiling the .forward() pass when the input_ids size is 1. However, when I invoke it with my custom Triton attention kernel, inductor re-records the function on every call, even after the warmup and even with identical input tensor shapes.

I’ve been trying to figure out the correct logging voodoo to get it to tell me what the cause of the re-recording is but the best I can get it to do is to verify that I am indeed recording every call (I have also verified this in the profiler).

I’m using this:

import torch
import torch._logging
import logging
torch._logging.set_logs(
    dynamo=logging.INFO, inductor=logging.INFO,
    recompiles=True, recompiles_verbose=True, cudagraphs=True, graph_breaks=True)

and I get this output (one line per invocation of my .forward()):

V0304 23:53:37.255000 15408 torch/_inductor/cudagraph_trees.py:2157] [__cudagraphs] Recording function 0 of graph recording id 10
TheV0304 23:53:37.468000 15408 torch/_inductor/cudagraph_trees.py:2157] [__cudagraphs] Recording function 0 of graph recording id 11
 tallestV0304 23:53:37.665000 15408 torch/_inductor/cudagraph_trees.py:2157] [__cudagraphs] Recording function 0 of graph recording id 12
 mountainV0304 23:53:37.846000 15408 torch/_inductor/cudagraph_trees.py:2157] [__cudagraphs] Recording function 0 of graph recording id 13
 inV0304 23:53:38.077000 15408 torch/_inductor/cudagraph_trees.py:2157] [__cudagraphs] Recording function 0 of graph recording id 14
 theV0304 23:53:38.334000 15408 torch/_inductor/cudagraph_trees.py:2157] [__cudagraphs] Recording function 0 of graph recording id 15
 worldV0304 23:53:38.589000 15408 torch/_inductor/cudagraph_trees.py:2157] [__cudagraphs] Recording function 0 of graph recording id 16
 isV0304 23:53:38.800000 15408 torch/_inductor/cudagraph_trees.py:2157] [__cudagraphs] Recording function 0 of graph recording id 17
 MountV0304 23:53:39.041000 15408 torch/_inductor/cudagraph_trees.py:2157] [__cudagraphs] Recording function 0 of graph recording id 18
 EverestV0304 23:53:39.264000 15408 torch/_inductor/cudagraph_trees.py:2157] [__cudagraphs] Recording function 0 of graph recording id 19
,V0304 23:53:39.482000 15408 torch/_inductor/cudagraph_trees.py:2157] [__cudagraphs] Recording function 0 of graph recording id 20

Any help on how I can get PyTorch to tell me why I am re-recording on every call would be great, as would any insight into what else besides tensor shapes is likely to be causing it.

Do you have a script we could run to reproduce this? I don’t know what’s up off of the top of my head. Maybe @Elias_Ellison?

Thanks for the reply @richard, but I ended up solving my problem.

The issue was that I was storing a tensor as a member variable and then using that tensor inside the compiled code. I don’t fully understand why that triggered the re-record, but changing the code so that the tensor was passed in as a parameter to my forward pass fixed the problem.

To be more specific, I have a KVCache that needs you to allocate indices ahead of time. The old code was something like this:

class KVCache:
    def reserve_indices(self):
        # Old approach: the indices are stored on the cache object
        self.reserved_indices = ...

def forward(kv_cache):
    ...  # use kv_cache.reserved_indices

To fix it, I simply changed KVCache.reserve_indices() to return the indices, and then passed them as an argument into forward(), and it stopped re-recording my function.
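For anyone hitting the same thing, here is a rough sketch of the shape of that change. It is not my real code: the KVCache internals (num_slots, next_free), the count argument, and the body of forward() are all made up for illustration, and forward() is left eager here so the snippet runs without a GPU. In my actual setup that function is the one being compiled.

```python
import torch


class KVCache:
    """Hypothetical stand-in for the real cache; only index handling is shown."""

    def __init__(self, num_slots: int):
        self.num_slots = num_slots
        self.next_free = 0  # assumed bookkeeping, not from the original code

    def reserve_indices(self, count: int) -> torch.Tensor:
        # Return the reserved indices to the caller instead of
        # storing them on self as a member variable.
        start = self.next_free
        if start + count > self.num_slots:
            raise ValueError("not enough free slots in the cache")
        self.next_free += count
        return torch.arange(start, start + count)


def forward(x: torch.Tensor, storage: torch.Tensor,
            reserved_indices: torch.Tensor) -> torch.Tensor:
    # The indices arrive as an explicit tensor argument rather than being
    # read from an attribute of the cache object inside the function.
    return storage.index_select(0, reserved_indices) + x


cache = KVCache(num_slots=8)
storage = torch.zeros(8, 4)
idx = cache.reserve_indices(2)
out = forward(torch.ones(2, 4), storage, idx)
```

The only point is the data flow: the caller reserves the indices, holds the returned tensor, and hands it to forward() alongside the other inputs.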

It still would be useful to know what logging settings I could set for inductor to tell me WHY it was re-recording the function as that would have saved me a lot of time, but for the moment my code is working.

@joev was your kv-cache an nn.Module? (I’m trying to reproduce your issue so that we can write a bug report to the team)