Hi, I am trying to automatically replace a given set of math operators using the FX tracer. For instance, I would like to replace the + operator used in the residual blocks of a resnet model with a custom CustomAdd layer.
EDIT: I think I found a solution by adding a CustomAdd layer on the fly to the GraphModule returned by symbolic_trace; see the code below.
import torch
import torch.nn as nn
from torch.fx import symbolic_trace
from torchvision.models.resnet import resnet18, ResNet18_Weights

class CustomAdd(nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
        """Replacement operator: computes a + b by stacking the two tensors and summing."""
        return torch.sum(torch.stack([a, b], dim=0), dim=0)

def search_and_replace(module: nn.Module) -> nn.Module:
    traced = symbolic_trace(module)
    op_idx = 0  # Count the number of replacements
    for n in traced.graph.nodes:
        if n.target.__repr__() == "<built-in function add>":
            # This is an ugly way of detecting the + operator,
            # I wonder if something better can be done
            with traced.graph.inserting_after(n):
                # Register the replacement layer on the traced GraphModule on the fly
                traced.add_module(f"stacked_adder_{op_idx}", CustomAdd())
                new_node = traced.graph.call_module(f"stacked_adder_{op_idx}", n.args, n.kwargs)
                n.replace_all_uses_with(new_node)
            # Remove the old node from the graph
            traced.graph.erase_node(n)
            # Increase the counter to add one layer per addition
            op_idx += 1
    traced.graph.lint()  # Sanity-check the rewritten graph before regenerating code
    traced.recompile()
    return traced

model = resnet18(weights=ResNet18_Weights.IMAGENET1K_V1)
print(search_and_replace(model))
Output:
ResNet(
(conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
(bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(relu): ReLU(inplace=True)
(maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
(layer1): Module(
(0): Module(
(conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(relu): ReLU(inplace=True)
(conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)
(1): Module(
(conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(relu): ReLU(inplace=True)
(conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)
)
(layer2): Module(
(0): Module(
(conv1): Conv2d(64, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
(bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(relu): ReLU(inplace=True)
(conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(downsample): Module(
(0): Conv2d(64, 128, kernel_size=(1, 1), stride=(2, 2), bias=False)
(1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)
)
(1): Module(
(conv1): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(relu): ReLU(inplace=True)
(conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)
)
(layer3): Module(
(0): Module(
(conv1): Conv2d(128, 256, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
(bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(relu): ReLU(inplace=True)
(conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(downsample): Module(
(0): Conv2d(128, 256, kernel_size=(1, 1), stride=(2, 2), bias=False)
(1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)
)
(1): Module(
(conv1): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(relu): ReLU(inplace=True)
(conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)
)
(layer4): Module(
(0): Module(
(conv1): Conv2d(256, 512, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
(bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(relu): ReLU(inplace=True)
(conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(downsample): Module(
(0): Conv2d(256, 512, kernel_size=(1, 1), stride=(2, 2), bias=False)
(1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)
)
(1): Module(
(conv1): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(relu): ReLU(inplace=True)
(conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)
)
(avgpool): AdaptiveAvgPool2d(output_size=(1, 1))
(fc): Linear(in_features=512, out_features=1000, bias=True)
(stacked_adder_0): CustomAdd()
(stacked_adder_1): CustomAdd()
(stacked_adder_2): CustomAdd()
(stacked_adder_3): CustomAdd()
(stacked_adder_4): CustomAdd()
(stacked_adder_5): CustomAdd()
(stacked_adder_6): CustomAdd()
(stacked_adder_7): CustomAdd()
)
def forward(self, x : torch.Tensor) -> torch.Tensor:
conv1 = self.conv1(x); x = None
bn1 = self.bn1(conv1); conv1 = None
relu = self.relu(bn1); bn1 = None
maxpool = self.maxpool(relu); relu = None
layer1_0_conv1 = getattr(self.layer1, "0").conv1(maxpool)
layer1_0_bn1 = getattr(self.layer1, "0").bn1(layer1_0_conv1); layer1_0_conv1 = None
layer1_0_relu = getattr(self.layer1, "0").relu(layer1_0_bn1); layer1_0_bn1 = None
layer1_0_conv2 = getattr(self.layer1, "0").conv2(layer1_0_relu); layer1_0_relu = None
layer1_0_bn2 = getattr(self.layer1, "0").bn2(layer1_0_conv2); layer1_0_conv2 = None
stacked_adder_0 = self.stacked_adder_0(layer1_0_bn2, maxpool); layer1_0_bn2 = maxpool = None
layer1_0_relu_1 = getattr(self.layer1, "0").relu(stacked_adder_0); stacked_adder_0 = None
layer1_1_conv1 = getattr(self.layer1, "1").conv1(layer1_0_relu_1)
layer1_1_bn1 = getattr(self.layer1, "1").bn1(layer1_1_conv1); layer1_1_conv1 = None
layer1_1_relu = getattr(self.layer1, "1").relu(layer1_1_bn1); layer1_1_bn1 = None
layer1_1_conv2 = getattr(self.layer1, "1").conv2(layer1_1_relu); layer1_1_relu = None
layer1_1_bn2 = getattr(self.layer1, "1").bn2(layer1_1_conv2); layer1_1_conv2 = None
stacked_adder_1 = self.stacked_adder_1(layer1_1_bn2, layer1_0_relu_1); layer1_1_bn2 = layer1_0_relu_1 = None
layer1_1_relu_1 = getattr(self.layer1, "1").relu(stacked_adder_1); stacked_adder_1 = None
layer2_0_conv1 = getattr(self.layer2, "0").conv1(layer1_1_relu_1)
layer2_0_bn1 = getattr(self.layer2, "0").bn1(layer2_0_conv1); layer2_0_conv1 = None
layer2_0_relu = getattr(self.layer2, "0").relu(layer2_0_bn1); layer2_0_bn1 = None
layer2_0_conv2 = getattr(self.layer2, "0").conv2(layer2_0_relu); layer2_0_relu = None
layer2_0_bn2 = getattr(self.layer2, "0").bn2(layer2_0_conv2); layer2_0_conv2 = None
layer2_0_downsample_0 = getattr(getattr(self.layer2, "0").downsample, "0")(layer1_1_relu_1); layer1_1_relu_1 = None
layer2_0_downsample_1 = getattr(getattr(self.layer2, "0").downsample, "1")(layer2_0_downsample_0); layer2_0_downsample_0 = None
stacked_adder_2 = self.stacked_adder_2(layer2_0_bn2, layer2_0_downsample_1); layer2_0_bn2 = layer2_0_downsample_1 = None
layer2_0_relu_1 = getattr(self.layer2, "0").relu(stacked_adder_2); stacked_adder_2 = None
layer2_1_conv1 = getattr(self.layer2, "1").conv1(layer2_0_relu_1)
layer2_1_bn1 = getattr(self.layer2, "1").bn1(layer2_1_conv1); layer2_1_conv1 = None
layer2_1_relu = getattr(self.layer2, "1").relu(layer2_1_bn1); layer2_1_bn1 = None
layer2_1_conv2 = getattr(self.layer2, "1").conv2(layer2_1_relu); layer2_1_relu = None
layer2_1_bn2 = getattr(self.layer2, "1").bn2(layer2_1_conv2); layer2_1_conv2 = None
stacked_adder_3 = self.stacked_adder_3(layer2_1_bn2, layer2_0_relu_1); layer2_1_bn2 = layer2_0_relu_1 = None
layer2_1_relu_1 = getattr(self.layer2, "1").relu(stacked_adder_3); stacked_adder_3 = None
layer3_0_conv1 = getattr(self.layer3, "0").conv1(layer2_1_relu_1)
layer3_0_bn1 = getattr(self.layer3, "0").bn1(layer3_0_conv1); layer3_0_conv1 = None
layer3_0_relu = getattr(self.layer3, "0").relu(layer3_0_bn1); layer3_0_bn1 = None
layer3_0_conv2 = getattr(self.layer3, "0").conv2(layer3_0_relu); layer3_0_relu = None
layer3_0_bn2 = getattr(self.layer3, "0").bn2(layer3_0_conv2); layer3_0_conv2 = None
layer3_0_downsample_0 = getattr(getattr(self.layer3, "0").downsample, "0")(layer2_1_relu_1); layer2_1_relu_1 = None
layer3_0_downsample_1 = getattr(getattr(self.layer3, "0").downsample, "1")(layer3_0_downsample_0); layer3_0_downsample_0 = None
stacked_adder_4 = self.stacked_adder_4(layer3_0_bn2, layer3_0_downsample_1); layer3_0_bn2 = layer3_0_downsample_1 = None
layer3_0_relu_1 = getattr(self.layer3, "0").relu(stacked_adder_4); stacked_adder_4 = None
layer3_1_conv1 = getattr(self.layer3, "1").conv1(layer3_0_relu_1)
layer3_1_bn1 = getattr(self.layer3, "1").bn1(layer3_1_conv1); layer3_1_conv1 = None
layer3_1_relu = getattr(self.layer3, "1").relu(layer3_1_bn1); layer3_1_bn1 = None
layer3_1_conv2 = getattr(self.layer3, "1").conv2(layer3_1_relu); layer3_1_relu = None
layer3_1_bn2 = getattr(self.layer3, "1").bn2(layer3_1_conv2); layer3_1_conv2 = None
stacked_adder_5 = self.stacked_adder_5(layer3_1_bn2, layer3_0_relu_1); layer3_1_bn2 = layer3_0_relu_1 = None
layer3_1_relu_1 = getattr(self.layer3, "1").relu(stacked_adder_5); stacked_adder_5 = None
layer4_0_conv1 = getattr(self.layer4, "0").conv1(layer3_1_relu_1)
layer4_0_bn1 = getattr(self.layer4, "0").bn1(layer4_0_conv1); layer4_0_conv1 = None
layer4_0_relu = getattr(self.layer4, "0").relu(layer4_0_bn1); layer4_0_bn1 = None
layer4_0_conv2 = getattr(self.layer4, "0").conv2(layer4_0_relu); layer4_0_relu = None
layer4_0_bn2 = getattr(self.layer4, "0").bn2(layer4_0_conv2); layer4_0_conv2 = None
layer4_0_downsample_0 = getattr(getattr(self.layer4, "0").downsample, "0")(layer3_1_relu_1); layer3_1_relu_1 = None
layer4_0_downsample_1 = getattr(getattr(self.layer4, "0").downsample, "1")(layer4_0_downsample_0); layer4_0_downsample_0 = None
stacked_adder_6 = self.stacked_adder_6(layer4_0_bn2, layer4_0_downsample_1); layer4_0_bn2 = layer4_0_downsample_1 = None
layer4_0_relu_1 = getattr(self.layer4, "0").relu(stacked_adder_6); stacked_adder_6 = None
layer4_1_conv1 = getattr(self.layer4, "1").conv1(layer4_0_relu_1)
layer4_1_bn1 = getattr(self.layer4, "1").bn1(layer4_1_conv1); layer4_1_conv1 = None
layer4_1_relu = getattr(self.layer4, "1").relu(layer4_1_bn1); layer4_1_bn1 = None
layer4_1_conv2 = getattr(self.layer4, "1").conv2(layer4_1_relu); layer4_1_relu = None
layer4_1_bn2 = getattr(self.layer4, "1").bn2(layer4_1_conv2); layer4_1_conv2 = None
stacked_adder_7 = self.stacked_adder_7(layer4_1_bn2, layer4_0_relu_1); layer4_1_bn2 = layer4_0_relu_1 = None
layer4_1_relu_1 = getattr(self.layer4, "1").relu(stacked_adder_7); stacked_adder_7 = None
avgpool = self.avgpool(layer4_1_relu_1); layer4_1_relu_1 = None
flatten = torch.flatten(avgpool, 1); avgpool = None
fc = self.fc(flatten); flatten = None
return fc
# To see more debug info, please use `graph_module.print_readable()`
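As an aside, the repr-string comparison is the part I am least happy with; a less fragile approach is probably to match the node's op and target directly, which is the usual FX idiom. Below is a minimal sketch of that check (the helper name find_add_nodes and the ADD_TARGETS set are mine, not from the run above). It assumes the additions show up as call_function nodes targeting operator.add, which is consistent with the "<built-in function add>" repr matched above; torch.add is included only in case a model calls it explicitly.

import operator

import torch
from torch.fx import GraphModule

# Call targets that correspond to an addition in the traced graph.
# operator.add is what the repr check above was matching;
# torch.add is added in case the model calls it explicitly.
ADD_TARGETS = {operator.add, torch.add}

def find_add_nodes(gm: GraphModule):
    """Yield the call_function nodes of `gm` that perform an addition."""
    for node in gm.graph.nodes:
        if node.op == "call_function" and node.target in ADD_TARGETS:
            yield node

In search_and_replace, the condition would then become n.op == "call_function" and n.target in ADD_TARGETS (or a loop over find_add_nodes(traced)). Since CustomAdd computes the same sum as +, running the original and the rewritten model on the same input in eval() mode and comparing the results with torch.allclose should confirm that the rewrite does not change the numerics.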