Would FX be able to support nn.Module patterns for replacement using the subgraph_rewriter? The examples and tests show simple functional patterns, but being able to define patterns of modules such as nn.Conv2d or nn.BatchNorm2d could be really powerful. I see the Conv-BN Fuser tutorial uses direct graph manipulation, but it would be a great candidate to try with replace_pattern, if that can be supported.
One consideration is defining patterns that look for call_module nodes whose target module is an nn.Conv2d, without hardcoding the kernel or channel sizes in the pattern itself, so that matching is based on the module type alone. That would be something like a wildcard (*) on the module's constructor arguments.
For example:
```python
import torch
import torch.fx as tfx

# Proposed syntax: "*" is a wildcard meaning "any value" (not valid Python today)
class Pattern(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = torch.nn.Conv2d(*, *, *)  # match all Conv2d
        self.bn = torch.nn.BatchNorm2d(*)     # match all BatchNorm2d

    def forward(self, x):
        return self.bn(self.conv(x))

class Replacement(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = torch.nn.Conv2d(*, *, *)  # single conv replacing conv + bn

    def forward(self, x):
        return self.conv(x)

tfx.subgraph_rewriter.replace_pattern(module.graph_module, Pattern(), Replacement())
```
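Until module patterns are supported, the type-based matching described above can be done by hand, in the spirit of the Conv-BN fuser tutorial. This is a minimal sketch, assuming the model traces cleanly with torch.fx.symbolic_trace; `find_conv_bn_pairs` and `Net` are hypothetical names for illustration. It only finds Conv2d→BatchNorm2d pairs by module type, regardless of kernel or channel sizes, and does not fuse any weights.

```python
import torch
import torch.fx as fx
import torch.nn as nn


def find_conv_bn_pairs(gm: fx.GraphModule):
    """Return (conv_node, bn_node) pairs matched purely by module type (hypothetical helper)."""
    modules = dict(gm.named_modules())
    pairs = []
    for node in gm.graph.nodes:
        # Only consider call_module nodes backed by any BatchNorm2d
        if node.op != "call_module" or not isinstance(modules[node.target], nn.BatchNorm2d):
            continue
        src = node.args[0]
        # The BN input must come from any Conv2d, whatever its kernel/channel sizes
        if (
            isinstance(src, fx.Node)
            and src.op == "call_module"
            and isinstance(modules[src.target], nn.Conv2d)
            and len(src.users) == 1  # conv output feeds only this BN
        ):
            pairs.append((src, node))
    return pairs


class Net(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 8, 3)
        self.bn1 = nn.BatchNorm2d(8)
        self.conv2 = nn.Conv2d(8, 16, 5, padding=2)
        self.bn2 = nn.BatchNorm2d(16)

    def forward(self, x):
        return self.bn2(self.conv2(torch.relu(self.bn1(self.conv1(x)))))


gm = fx.symbolic_trace(Net())
print([(c.target, b.target) for c, b in find_conv_bn_pairs(gm)])
# [('conv1', 'bn1'), ('conv2', 'bn2')]
```

Both pairs match even though the two convolutions use different kernel and channel sizes, which is the behavior a wildcard module pattern would ideally give replace_pattern.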
Related topic: Torch.fx replace modules