[TorchDynamo] How to generate fx graph with nn.* layers?

Hi Folks,

I'm trying to generate a graph that has nn.* layer types in the “target” field of the FX nodes.

For example:

import logging

import torch
import torch.nn as nn

class WrappedBatchNorm(nn.Module):
    def __init__(self):
        super().__init__()
        self.mod = nn.BatchNorm2d(1)
    def forward(self, x):
        return self.mod(x)

class M(nn.Module):
    def __init__(self):
        super().__init__()
        self.layer1 = nn.Conv2d(1, 1, 1)
        self.layer2 = nn.BatchNorm2d(1)
        self.layer3 = nn.Conv2d(1, 1, 1)
        self.nested = nn.Sequential(
            nn.BatchNorm2d(1),
            nn.Conv2d(1, 1, 1),
        )
        self.wrapped = WrappedBatchNorm()

    def forward(self, x):
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.nested(x)
        x = self.wrapped(x)
        return x

def my_custom_backend(gm, example_inputs):
    print("My Custom Backend:")
    print(gm.graph)
    return gm.forward

#dyn.config.skipfiles_inline_module_allowlist.add(M)
torch._logging.set_logs(dynamo=logging.DEBUG)

mod = M()
mod.eval()
mod = torch.compile(mod, backend=my_custom_backend)
mod(torch.randn(1,1,4,4))

This generates the following graph:

My Custom Backend:
graph():
    %l_x_ : torch.Tensor [#users=1] = placeholder[target=L_x_]
    %l__self___layer1 : [#users=1] = call_module[target=L__self___layer1](args = (%l_x_,), kwargs = {})
    %l__self___layer2 : [#users=1] = call_module[target=L__self___layer2](args = (%l__self___layer1,), kwargs = {})
    %l__self___layer3 : [#users=1] = call_module[target=L__self___layer3](args = (%l__self___layer2,), kwargs = {})
    %l__self___nested_0 : [#users=1] = call_module[target=L__self___nested_0](args = (%l__self___layer3,), kwargs = {})
    %l__self___nested_1 : [#users=1] = call_module[target=L__self___nested_1](args = (%l__self___nested_0,), kwargs = {})
    %l__self___wrapped_mod : [#users=1] = call_module[target=L__self___wrapped_mod](args = (%l__self___nested_1,), kwargs = {})
    return (l__self___wrapped_mod,)
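The `target` of a `call_module` node is the qualified attribute name of the submodule on the owning `GraphModule` (here `L__self___layer1` and so on), not its class. The class can be recovered by looking that name up with `GraphModule.get_submodule`. The sketch below shows this on `torch.fx.symbolic_trace` output rather than on Dynamo output. That is an assumption made so the graph does not depend on version-specific Dynamo inlining behavior. The helper name `module_types` is hypothetical.

```python
import torch.fx as fx
import torch.nn as nn


class WrappedBatchNorm(nn.Module):
    def __init__(self):
        super().__init__()
        self.mod = nn.BatchNorm2d(1)

    def forward(self, x):
        return self.mod(x)


class M(nn.Module):
    def __init__(self):
        super().__init__()
        self.layer1 = nn.Conv2d(1, 1, 1)
        self.layer2 = nn.BatchNorm2d(1)
        self.layer3 = nn.Conv2d(1, 1, 1)
        self.nested = nn.Sequential(nn.BatchNorm2d(1), nn.Conv2d(1, 1, 1))
        self.wrapped = WrappedBatchNorm()

    def forward(self, x):
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.nested(x)
        return self.wrapped(x)


def module_types(gm: fx.GraphModule):
    """Hypothetical helper: pair each call_module target with its nn.Module class."""
    result = []
    for node in gm.graph.nodes:
        if node.op == "call_module":
            # node.target is a dotted attribute path such as "nested.0";
            # get_submodule resolves it to the module instance stored on gm.
            result.append((node.target, type(gm.get_submodule(node.target))))
    return result


# symbolic_trace keeps built-in torch.nn layers as leaf call_module nodes,
# but traces through nn.Sequential and user-defined modules like WrappedBatchNorm.
gm = fx.symbolic_trace(M().eval())
for target, cls in module_types(gm):
    print(target, cls.__name__)
```

Because the graph stores only the attribute path, matching on `type(...)` or `isinstance(...)` of the resolved module is usually more robust than parsing the target string.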

Since I want to do some pattern matching, I'd like it to generate a graph with “target=nn.*”, e.g.:
%xxx : [#users=1] = call_module[target=nn.Conv2d](args =…)

What should I do? Can it trace into the self.xxx call and inline it?
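For pattern matching, one option (an assumption, not necessarily what the poster needs) is to keep the `call_module` nodes and match on the resolved module class instead of on the target string. The sketch below looks for a `Conv2d` whose only user is a `BatchNorm2d`, the usual conv+BN fusion pattern. `ConvBN`, `is_module`, and `find_conv_bn` are hypothetical names, and the graph again comes from `torch.fx.symbolic_trace` rather than from a Dynamo backend.

```python
import torch.fx as fx
import torch.nn as nn


class ConvBN(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 1, 1)
        self.bn1 = nn.BatchNorm2d(1)
        self.relu = nn.ReLU()
        self.conv2 = nn.Conv2d(1, 1, 1)
        self.bn2 = nn.BatchNorm2d(1)

    def forward(self, x):
        x = self.relu(self.bn1(self.conv1(x)))
        return self.bn2(self.conv2(x))


def is_module(gm: fx.GraphModule, node, cls) -> bool:
    """True if node calls a submodule of gm that is an instance of cls."""
    return (
        isinstance(node, fx.Node)
        and node.op == "call_module"
        and isinstance(gm.get_submodule(node.target), cls)
    )


def find_conv_bn(gm: fx.GraphModule):
    """Return (conv_target, bn_target) pairs where a BatchNorm2d directly
    consumes a Conv2d output and the Conv2d has no other users."""
    matches = []
    for node in gm.graph.nodes:
        if not is_module(gm, node, nn.BatchNorm2d):
            continue
        src = node.args[0]
        if is_module(gm, src, nn.Conv2d) and len(src.users) == 1:
            matches.append((src.target, node.target))
    return matches


gm = fx.symbolic_trace(ConvBN().eval())
print(find_conv_bn(gm))
```

Requiring the conv to have a single user keeps a rewrite (e.g. folding BN into the conv weights) from changing the result seen by any other consumer of the conv output.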

Thanks

It sounds like what you need is the torch IR; this tutorial might help: Google Colab

Thank you! This makes sense