Replace Pattern for nn.Modules?

Would FX be able to support nn.Module patterns for replacement with the subgraph_rewriter? The examples and tests show simple functional patterns, but being able to define patterns over modules like nn.Conv2d or nn.BatchNorm2d could be really powerful. I see the Conv-BN Fuser tutorial uses direct graph manipulation, but that fusion would be a great candidate for replace_pattern if module patterns were supported. One consideration is defining patterns that match call_module ops with nn.Conv2d targets without hardcoding the kernel or channel sizes in the pattern itself, so that matching is based on op type alone. In other words, something like a regex wildcard (*) over the module parameters:

For example:

        class Pattern(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.conv = torch.nn.Conv2d(*, *, *)  # match any Conv2d
                self.bn = torch.nn.BatchNorm2d(*)     # match any BatchNorm2d

            def forward(self, x):
                return self.bn(self.conv(x))


        class Replacement(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.conv = torch.nn.Conv2d(*, *, *)

            def forward(self, x):
                return self.conv(x)

        torch.fx.subgraph_rewriter.replace_pattern(module.graph_module, Pattern(), Replacement())

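In the meantime, the type-based matching half of this can be approximated with direct graph manipulation, since call_module nodes carry the submodule's qualified name and the actual module instance can be looked up and checked with isinstance (so kernel/channel sizes never enter into it). Below is a minimal sketch of that idea; the `Model` and `find_conv_bn_pairs` names are just illustrative, and actually fusing the matched pair numerically would still need something like the folding done in the Conv-BN Fuser tutorial:

```python
import torch
import torch.fx as fx

class Model(torch.nn.Module):
    # Toy model with one Conv2d -> BatchNorm2d chain to match against.
    def __init__(self):
        super().__init__()
        self.conv = torch.nn.Conv2d(3, 8, kernel_size=3)
        self.bn = torch.nn.BatchNorm2d(8)

    def forward(self, x):
        return self.bn(self.conv(x))

def find_conv_bn_pairs(gm: fx.GraphModule):
    """Return (conv_node, bn_node) pairs, matched by module *type* only --
    the wildcard behavior asked for above, with no hardcoded sizes."""
    modules = dict(gm.named_modules())
    pairs = []
    for node in gm.graph.nodes:
        if node.op != "call_module":
            continue
        if not isinstance(modules[node.target], torch.nn.BatchNorm2d):
            continue
        prev = node.args[0]  # the node feeding the batch norm
        if (isinstance(prev, fx.Node) and prev.op == "call_module"
                and isinstance(modules[prev.target], torch.nn.Conv2d)):
            pairs.append((prev, node))
    return pairs

gm = fx.symbolic_trace(Model())
pairs = find_conv_bn_pairs(gm)
print(len(pairs))  # one Conv2d -> BatchNorm2d pair in this toy model
```

This is essentially what a wildcard-capable replace_pattern would do internally: treat the pattern's modules as type constraints rather than comparing concrete parameters.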
Related topic: Torch.fx replace modules
