How to avoid check_object_id (ID_MATCH) on nn.Module objects?

I got stuck on torch.compile introducing an ID_MATCH guard for each object derived from nn.Module, even though the logic of the computation depends only on the attributes of that module. This causes a very large number of recompilations in my current model.
Here is an MWE:

import torch
import torch.nn as nn
from torch import Tensor
import logging

# enable dynamo debug output together with guard printing
torch._logging.set_logs(dynamo=logging.DEBUG, guards=True)


class A(nn.Module):
    def __init__(self):
        super().__init__()
        self.b = 1.0

    def forward(self, x):
        return x.sin().sin() + self.b


def foo(a: A, x: Tensor):
    return a.forward(x)


compile_args = dict(
    fullgraph=True,
    dynamic=True,
    backend="inductor",
    options={
        'group_fusion': True,
        'force_same_precision': True,
        'disable_cpp_codegen': True,
        'trace.graph_diagram': True,
        'triton.cudagraphs': False,
    },
)

x = torch.rand(10)
a = A()
foo(a, x)                                  # eager call
torch.compile(foo, **compile_args)(a, x)   # compiled call

What I am concerned with is that the output contains:

    torch/_dynamo/guards.py:2148] [0/0] [__guards] | | +- ID_MATCH: ___check_obj_id(L['a'], 140139899637648)                      # return a.forward(x)  # quant/id_check.py:18 in foo

This guard checks that a is the very same object, although the same compiled function would work for any object of that class.
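To make this concrete: in eager mode any instance of A produces the same result here, because the computation depends only on the attribute self.b (a quick sanity check):

a2 = A()                                       # a different object of the same class
assert torch.allclose(foo(a, x), foo(a2, x))   # same result, the object id is irrelevant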

Why does this ID_MATCH happen on nn.Module? It does not happen if A is not derived from it. How does Dynamo decide which objects are "good" and which are "bad" (i.e. must match by id rather than by contents)? And is there a way to disable this ID_MATCH selectively while keeping A derived from nn.Module? Thanks for any hints.

I found a related problem that was solved by restructuring the code so that torch.compile never gets to see those Modules, only tensors.

Such a solution would greatly reduce the flexibility of prototyping and experimenting in my case.
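For reference, the restructuring I mean would look roughly like this (a sketch with made-up names, not something I actually want to commit to): only tensors and plain Python data enter the compiled function, and the module is unpacked outside of it.

def foo_tensors_only(b: float, x: Tensor):
    # only plain data enters the compiled region, no nn.Module
    return x.sin().sin() + b

cfoo_tensors = torch.compile(foo_tensors_only, **compile_args)

def foo_wrapper(a: A, x: Tensor):
    # unpack the module attributes before entering compiled code
    return cfoo_tensors(a.b, x)

As far as I understand, plain Python scalar arguments can still get specialized on their value, so even this is not automatically free of per-instance guards.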

After additional tests, it seems that despite the guard on the object id of a, there is no recompilation when I use the function with a different object. I’ve checked with

torch._logging.set_logs(recompiles=True)

It does not print any recompilations, and timing the execution also shows no sign of them. It seems I falsely identified ___check_obj_id on nn.Module as the reason for the recompilations in my real use case. But then I don't understand why this guard is present at all?
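For completeness, the check I ran was roughly:

torch._logging.set_logs(recompiles=True)
cfoo = torch.compile(foo, **compile_args)
a1, a2 = A(), A()
cfoo(a1, x)   # compiles on the first call
cfoo(a2, x)   # different object id, yet no recompilation is logged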

I have now tried with

torch._logging.set_logs(recompiles=True, recompiles_verbose=True, fusion=True)

and it repeatedly prints fusion attempts, suggesting that it tries to recompile the graph on each call. But recompiles still reports nothing.
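To separate actual Dynamo recompilations from Inductor merely logging its fusion passes, one check I can think of is a tiny custom backend that just counts how often a graph gets (re)compiled (a sketch; counting_backend is my own helper, not a torch API, and it reuses the same A, forward, nets and x as in the complete code below):

compile_calls = 0

def counting_backend(gm, example_inputs):
    # Dynamo invokes the backend once per (re)compiled graph
    global compile_calls
    compile_calls += 1
    return gm.forward  # skip Inductor, just run the captured FX graph

cfwd = torch.compile(forward, fullgraph=True, dynamic=True, backend=counting_backend)
for net in nets:
    cfwd(net, x)
print("graph compilations:", compile_calls)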

Here is the complete code, where I added timing of the compilation and benchmarking of the compiled function.

import torch
import torch.nn as nn
from torch import Tensor
import logging

# torch._logging.set_logs(guards=True)
torch._logging.set_logs(recompiles=True, recompiles_verbose=True, fusion=True)


class A(nn.Module):
    def __init__(self, v: float):
        super().__init__()
        self.b = v

    def forward(self, x):
        return (x.sin() + self.b).sin()


def forward(net, x: Tensor):
    return net(x).sum()


dev = torch.device('cuda:0')

compile_args = dict(
    fullgraph=True,
    dynamic=True,
    backend="inductor",
    options={
        'group_fusion': True,
        'force_same_precision': True,
        'disable_cpp_codegen': False,
        'trace.graph_diagram': True,
        'triton.cudagraphs': False,
    },
)
cforward = torch.compile(forward, **compile_args)


def compute(net):
    y = cforward(net, x)
    y.backward()


start_e = torch.cuda.Event(enable_timing=True)
end_e = torch.cuda.Event(enable_timing=True)
x = torch.rand((100, 100, 100), device=dev, requires_grad=True)

nets = [A(float(j)) for j in range(3)]
# nets = [nn.Sequential(*[A(float(j)) for j in range(3)]) for i in range(3)]

# WARMUP / COMPILE
s = torch.cuda.Stream(device=dev)
s.wait_stream(torch.cuda.current_stream())
with torch.cuda.device(dev):
    with torch.cuda.stream(s):  # warmup on a side stream, following the CUDA graph examples
        for i in range(3):
            print(f"_______________Net {i}_______________")
            x.grad = None
            start_e.record()
            compute(nets[i])
            end_e.record()
            torch.cuda.synchronize()
            t = start_e.elapsed_time(end_e)
            print(f'Time: {t} ms')
            print(x.grad.mean())

    torch.cuda.current_stream().wait_stream(s)

    # capture forward + backward of all three nets into one CUDA graph
    G = torch.cuda.CUDAGraph()
    with torch.cuda.graph(G):
        x.grad = None
        for i in range(3):
            compute(nets[i])

    for i in range(3):
        print(f"_______________Compiled Graphed Iteration {i}_______________")
        tt = 0
        for j in range(100):
            start_e.record()
            G.replay()
            end_e.record()
            torch.cuda.synchronize()
            t = start_e.elapsed_time(end_e)
            tt += t
        print(f'Time: {tt/100} ms')
        print(x.grad.mean())

It prints

_______________Net 0_______________
V0902 13:17:17.227000 140492323473216 torch/_inductor/scheduler.py:1820] [0/0] [__fusion] ===== attempting fusion (1/10): 2 nodes =====
V0902 13:17:17.227000 140492323473216 torch/_inductor/scheduler.py:627] [0/0] [__fusion] cannot fuse buf0 with buf1: numel/rnumel mismatch (reduce) (123, 1), (((s0**3 + 122)//123), 123)
V0902 13:17:17.227000 140492323473216 torch/_inductor/scheduler.py:2133] [0/0] [__fusion] found 0 possible fusions
V0902 13:17:17.228000 140492323473216 torch/_inductor/scheduler.py:1827] [0/0] [__fusion] completed fusion round (1/10): fused 2 nodes into 2 nodes
V0902 13:17:17.228000 140492323473216 torch/_inductor/scheduler.py:1827] [0/0] [__fusion] 
V0902 13:17:17.228000 140492323473216 torch/_inductor/scheduler.py:1834] [0/0] [__fusion] ===== fusion complete (1 iterations) =====
V0902 13:17:18.419000 140492323473216 torch/_inductor/scheduler.py:1820] [0/0] [__fusion] ===== attempting fusion (1/10): 1 nodes =====
V0902 13:17:18.419000 140492323473216 torch/_inductor/scheduler.py:2133] [0/0] [__fusion] found 0 possible fusions
V0902 13:17:18.419000 140492323473216 torch/_inductor/scheduler.py:1827] [0/0] [__fusion] completed fusion round (1/10): fused 1 nodes into 1 nodes
V0902 13:17:18.419000 140492323473216 torch/_inductor/scheduler.py:1827] [0/0] [__fusion] 
V0902 13:17:18.419000 140492323473216 torch/_inductor/scheduler.py:1834] [0/0] [__fusion] ===== fusion complete (1 iterations) =====
Time: 2500.05810546875 ms
tensor(0.7457, device='cuda:0')
_______________Net 1_______________
V0902 13:17:19.046000 140492323473216 torch/_inductor/scheduler.py:1820] [0/1] [__fusion] ===== attempting fusion (1/10): 2 nodes =====
V0902 13:17:19.046000 140492323473216 torch/_inductor/scheduler.py:627] [0/1] [__fusion] cannot fuse buf0 with buf1: numel/rnumel mismatch (reduce) (123, 1), (((s0**3 + 122)//123), 123)
V0902 13:17:19.046000 140492323473216 torch/_inductor/scheduler.py:2133] [0/1] [__fusion] found 0 possible fusions
V0902 13:17:19.047000 140492323473216 torch/_inductor/scheduler.py:1827] [0/1] [__fusion] completed fusion round (1/10): fused 2 nodes into 2 nodes
V0902 13:17:19.047000 140492323473216 torch/_inductor/scheduler.py:1827] [0/1] [__fusion] 
V0902 13:17:19.047000 140492323473216 torch/_inductor/scheduler.py:1834] [0/1] [__fusion] ===== fusion complete (1 iterations) =====
V0902 13:17:19.260000 140492323473216 torch/_inductor/scheduler.py:1820] [0/1] [__fusion] ===== attempting fusion (1/10): 1 nodes =====
V0902 13:17:19.260000 140492323473216 torch/_inductor/scheduler.py:2133] [0/1] [__fusion] found 0 possible fusions
V0902 13:17:19.260000 140492323473216 torch/_inductor/scheduler.py:1827] [0/1] [__fusion] completed fusion round (1/10): fused 1 nodes into 1 nodes
V0902 13:17:19.260000 140492323473216 torch/_inductor/scheduler.py:1827] [0/1] [__fusion] 
V0902 13:17:19.260000 140492323473216 torch/_inductor/scheduler.py:1834] [0/1] [__fusion] ===== fusion complete (1 iterations) =====
Time: 517.2654418945312 ms
tensor(0.1222, device='cuda:0')
_______________Net 2_______________
V0902 13:17:19.534000 140492323473216 torch/_inductor/scheduler.py:1820] [0/2] [__fusion] ===== attempting fusion (1/10): 2 nodes =====
V0902 13:17:19.534000 140492323473216 torch/_inductor/scheduler.py:627] [0/2] [__fusion] cannot fuse buf0 with buf1: numel/rnumel mismatch (reduce) (123, 1), (((s0**3 + 122)//123), 123)
V0902 13:17:19.534000 140492323473216 torch/_inductor/scheduler.py:2133] [0/2] [__fusion] found 0 possible fusions
V0902 13:17:19.534000 140492323473216 torch/_inductor/scheduler.py:1827] [0/2] [__fusion] completed fusion round (1/10): fused 2 nodes into 2 nodes
V0902 13:17:19.534000 140492323473216 torch/_inductor/scheduler.py:1827] [0/2] [__fusion] 
V0902 13:17:19.535000 140492323473216 torch/_inductor/scheduler.py:1834] [0/2] [__fusion] ===== fusion complete (1 iterations) =====
V0902 13:17:19.746000 140492323473216 torch/_inductor/scheduler.py:1820] [0/2] [__fusion] ===== attempting fusion (1/10): 1 nodes =====
V0902 13:17:19.746000 140492323473216 torch/_inductor/scheduler.py:2133] [0/2] [__fusion] found 0 possible fusions
V0902 13:17:19.746000 140492323473216 torch/_inductor/scheduler.py:1827] [0/2] [__fusion] completed fusion round (1/10): fused 1 nodes into 1 nodes
V0902 13:17:19.746000 140492323473216 torch/_inductor/scheduler.py:1827] [0/2] [__fusion] 
V0902 13:17:19.746000 140492323473216 torch/_inductor/scheduler.py:1834] [0/2] [__fusion] ===== fusion complete (1 iterations) =====
Time: 446.4669189453125 ms

I expected to see compilation for the first net only, with the same compiled code reused afterwards. However, the later attempts are much faster than the first one, so some compiled code is apparently being reused. That's confusing. Do I misunderstand something, or is this really an issue (or several issues)?
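In case it helps with diagnosing this, one more thing I can print after the warmup loop is Dynamo's accumulated compilation time report (assuming torch._dynamo.utils.compile_times() is available in this version):

import torch._dynamo.utils as dynamo_utils

# after the warmup loop over the three nets
print(dynamo_utils.compile_times())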