Replacing math operator with custom layer inside existing model

Hi, I am trying to automatically replace a given set of math operators using the FX tracer. For instance, I would like to replace the + operator used in the residual blocks of a ResNet model with a custom CustomAdd layer.

EDIT: I think I found a solution by adding CustomAdd layers on the fly to the GraphModule returned by symbolic_trace; see the code below.

import torch
from torchvision.models.resnet import resnet18, ResNet18_Weights
from torch.fx import symbolic_trace
import torch.nn as nn


class CustomAdd(nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
        """Replacement operator"""
        return torch.sum(torch.stack([a, b], dim=0), dim=0)


def search_and_replace(module: nn.Module) -> nn.Module:
    traced = symbolic_trace(module)
    op_idx = 0  # count the number of replacements

    for n in traced.graph.nodes:
        if n.target.__repr__() == "<built-in function add>":
            # This is an ugly way of detecting the + operator, I wonder if
            # something better can be done (see the note right after this script)
            with traced.graph.inserting_after(n):
                # Register the replacement layer on the traced GraphModule
                # (not on the original module), so the generated forward can
                # resolve self.stacked_adder_{op_idx}
                traced.add_module(f"stacked_adder_{op_idx}", CustomAdd())
                new_node = traced.graph.call_module(f"stacked_adder_{op_idx}", n.args, n.kwargs)
                n.replace_all_uses_with(new_node)
            # Remove the old node from the graph
            traced.graph.erase_node(n)
            # Increase the counter so that each addition gets its own layer
            op_idx += 1
    traced.recompile()
    return traced


model = resnet18(weights=ResNet18_Weights.IMAGENET1K_V1)
print(search_and_replace(model))
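
Side note on the detection: instead of matching the repr string of the node target, it seems cleaner to check the node's op and target directly. A minimal sketch of such a check (is_add_node is just an illustrative helper name; torch.add is included only in case the model uses the functional form):

import operator
import torch

def is_add_node(n: torch.fx.Node) -> bool:
    # x + y traces to a call_function node whose target is operator.add,
    # torch.add(x, y) to one whose target is torch.add
    return n.op == "call_function" and n.target in (operator.add, torch.add)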

Output of print(search_and_replace(model)):

ResNet(
  (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (layer1): Module(
    (0): Module(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (1): Module(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
  )
  (layer2): Module(
    (0): Module(
      (conv1): Conv2d(64, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (downsample): Module(
        (0): Conv2d(64, 128, kernel_size=(1, 1), stride=(2, 2), bias=False)
        (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
    )
    (1): Module(
      (conv1): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
  )
  (layer3): Module(
    (0): Module(
      (conv1): Conv2d(128, 256, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (downsample): Module(
        (0): Conv2d(128, 256, kernel_size=(1, 1), stride=(2, 2), bias=False)
        (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
    )
    (1): Module(
      (conv1): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
  )
  (layer4): Module(
    (0): Module(
      (conv1): Conv2d(256, 512, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (downsample): Module(
        (0): Conv2d(256, 512, kernel_size=(1, 1), stride=(2, 2), bias=False)
        (1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
    )
    (1): Module(
      (conv1): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
  )
  (avgpool): AdaptiveAvgPool2d(output_size=(1, 1))
  (fc): Linear(in_features=512, out_features=1000, bias=True)
  (stacked_adder_0): CustomAdd()
  (stacked_adder_1): CustomAdd()
  (stacked_adder_2): CustomAdd()
  (stacked_adder_3): CustomAdd()
  (stacked_adder_4): CustomAdd()
  (stacked_adder_5): CustomAdd()
  (stacked_adder_6): CustomAdd()
  (stacked_adder_7): CustomAdd()
)



def forward(self, x : torch.Tensor) -> torch.Tensor:
    conv1 = self.conv1(x);  x = None
    bn1 = self.bn1(conv1);  conv1 = None
    relu = self.relu(bn1);  bn1 = None
    maxpool = self.maxpool(relu);  relu = None
    layer1_0_conv1 = getattr(self.layer1, "0").conv1(maxpool)
    layer1_0_bn1 = getattr(self.layer1, "0").bn1(layer1_0_conv1);  layer1_0_conv1 = None
    layer1_0_relu = getattr(self.layer1, "0").relu(layer1_0_bn1);  layer1_0_bn1 = None
    layer1_0_conv2 = getattr(self.layer1, "0").conv2(layer1_0_relu);  layer1_0_relu = None
    layer1_0_bn2 = getattr(self.layer1, "0").bn2(layer1_0_conv2);  layer1_0_conv2 = None
    stacked_adder_0 = self.stacked_adder_0(layer1_0_bn2, maxpool);  layer1_0_bn2 = maxpool = None
    layer1_0_relu_1 = getattr(self.layer1, "0").relu(stacked_adder_0);  stacked_adder_0 = None
    layer1_1_conv1 = getattr(self.layer1, "1").conv1(layer1_0_relu_1)
    layer1_1_bn1 = getattr(self.layer1, "1").bn1(layer1_1_conv1);  layer1_1_conv1 = None
    layer1_1_relu = getattr(self.layer1, "1").relu(layer1_1_bn1);  layer1_1_bn1 = None
    layer1_1_conv2 = getattr(self.layer1, "1").conv2(layer1_1_relu);  layer1_1_relu = None
    layer1_1_bn2 = getattr(self.layer1, "1").bn2(layer1_1_conv2);  layer1_1_conv2 = None
    stacked_adder_1 = self.stacked_adder_1(layer1_1_bn2, layer1_0_relu_1);  layer1_1_bn2 = layer1_0_relu_1 = None
    layer1_1_relu_1 = getattr(self.layer1, "1").relu(stacked_adder_1);  stacked_adder_1 = None
    layer2_0_conv1 = getattr(self.layer2, "0").conv1(layer1_1_relu_1)
    layer2_0_bn1 = getattr(self.layer2, "0").bn1(layer2_0_conv1);  layer2_0_conv1 = None
    layer2_0_relu = getattr(self.layer2, "0").relu(layer2_0_bn1);  layer2_0_bn1 = None
    layer2_0_conv2 = getattr(self.layer2, "0").conv2(layer2_0_relu);  layer2_0_relu = None
    layer2_0_bn2 = getattr(self.layer2, "0").bn2(layer2_0_conv2);  layer2_0_conv2 = None
    layer2_0_downsample_0 = getattr(getattr(self.layer2, "0").downsample, "0")(layer1_1_relu_1);  layer1_1_relu_1 = None
    layer2_0_downsample_1 = getattr(getattr(self.layer2, "0").downsample, "1")(layer2_0_downsample_0);  layer2_0_downsample_0 = None
    stacked_adder_2 = self.stacked_adder_2(layer2_0_bn2, layer2_0_downsample_1);  layer2_0_bn2 = layer2_0_downsample_1 = None
    layer2_0_relu_1 = getattr(self.layer2, "0").relu(stacked_adder_2);  stacked_adder_2 = None
    layer2_1_conv1 = getattr(self.layer2, "1").conv1(layer2_0_relu_1)
    layer2_1_bn1 = getattr(self.layer2, "1").bn1(layer2_1_conv1);  layer2_1_conv1 = None
    layer2_1_relu = getattr(self.layer2, "1").relu(layer2_1_bn1);  layer2_1_bn1 = None
    layer2_1_conv2 = getattr(self.layer2, "1").conv2(layer2_1_relu);  layer2_1_relu = None
    layer2_1_bn2 = getattr(self.layer2, "1").bn2(layer2_1_conv2);  layer2_1_conv2 = None
    stacked_adder_3 = self.stacked_adder_3(layer2_1_bn2, layer2_0_relu_1);  layer2_1_bn2 = layer2_0_relu_1 = None
    layer2_1_relu_1 = getattr(self.layer2, "1").relu(stacked_adder_3);  stacked_adder_3 = None
    layer3_0_conv1 = getattr(self.layer3, "0").conv1(layer2_1_relu_1)
    layer3_0_bn1 = getattr(self.layer3, "0").bn1(layer3_0_conv1);  layer3_0_conv1 = None
    layer3_0_relu = getattr(self.layer3, "0").relu(layer3_0_bn1);  layer3_0_bn1 = None
    layer3_0_conv2 = getattr(self.layer3, "0").conv2(layer3_0_relu);  layer3_0_relu = None
    layer3_0_bn2 = getattr(self.layer3, "0").bn2(layer3_0_conv2);  layer3_0_conv2 = None
    layer3_0_downsample_0 = getattr(getattr(self.layer3, "0").downsample, "0")(layer2_1_relu_1);  layer2_1_relu_1 = None
    layer3_0_downsample_1 = getattr(getattr(self.layer3, "0").downsample, "1")(layer3_0_downsample_0);  layer3_0_downsample_0 = None
    stacked_adder_4 = self.stacked_adder_4(layer3_0_bn2, layer3_0_downsample_1);  layer3_0_bn2 = layer3_0_downsample_1 = None
    layer3_0_relu_1 = getattr(self.layer3, "0").relu(stacked_adder_4);  stacked_adder_4 = None
    layer3_1_conv1 = getattr(self.layer3, "1").conv1(layer3_0_relu_1)
    layer3_1_bn1 = getattr(self.layer3, "1").bn1(layer3_1_conv1);  layer3_1_conv1 = None
    layer3_1_relu = getattr(self.layer3, "1").relu(layer3_1_bn1);  layer3_1_bn1 = None
    layer3_1_conv2 = getattr(self.layer3, "1").conv2(layer3_1_relu);  layer3_1_relu = None
    layer3_1_bn2 = getattr(self.layer3, "1").bn2(layer3_1_conv2);  layer3_1_conv2 = None
    stacked_adder_5 = self.stacked_adder_5(layer3_1_bn2, layer3_0_relu_1);  layer3_1_bn2 = layer3_0_relu_1 = None
    layer3_1_relu_1 = getattr(self.layer3, "1").relu(stacked_adder_5);  stacked_adder_5 = None
    layer4_0_conv1 = getattr(self.layer4, "0").conv1(layer3_1_relu_1)
    layer4_0_bn1 = getattr(self.layer4, "0").bn1(layer4_0_conv1);  layer4_0_conv1 = None
    layer4_0_relu = getattr(self.layer4, "0").relu(layer4_0_bn1);  layer4_0_bn1 = None
    layer4_0_conv2 = getattr(self.layer4, "0").conv2(layer4_0_relu);  layer4_0_relu = None
    layer4_0_bn2 = getattr(self.layer4, "0").bn2(layer4_0_conv2);  layer4_0_conv2 = None
    layer4_0_downsample_0 = getattr(getattr(self.layer4, "0").downsample, "0")(layer3_1_relu_1);  layer3_1_relu_1 = None
    layer4_0_downsample_1 = getattr(getattr(self.layer4, "0").downsample, "1")(layer4_0_downsample_0);  layer4_0_downsample_0 = None
    stacked_adder_6 = self.stacked_adder_6(layer4_0_bn2, layer4_0_downsample_1);  layer4_0_bn2 = layer4_0_downsample_1 = None
    layer4_0_relu_1 = getattr(self.layer4, "0").relu(stacked_adder_6);  stacked_adder_6 = None
    layer4_1_conv1 = getattr(self.layer4, "1").conv1(layer4_0_relu_1)
    layer4_1_bn1 = getattr(self.layer4, "1").bn1(layer4_1_conv1);  layer4_1_conv1 = None
    layer4_1_relu = getattr(self.layer4, "1").relu(layer4_1_bn1);  layer4_1_bn1 = None
    layer4_1_conv2 = getattr(self.layer4, "1").conv2(layer4_1_relu);  layer4_1_relu = None
    layer4_1_bn2 = getattr(self.layer4, "1").bn2(layer4_1_conv2);  layer4_1_conv2 = None
    stacked_adder_7 = self.stacked_adder_7(layer4_1_bn2, layer4_0_relu_1);  layer4_1_bn2 = layer4_0_relu_1 = None
    layer4_1_relu_1 = getattr(self.layer4, "1").relu(stacked_adder_7);  stacked_adder_7 = None
    avgpool = self.avgpool(layer4_1_relu_1);  layer4_1_relu_1 = None
    flatten = torch.flatten(avgpool, 1);  avgpool = None
    fc = self.fc(flatten);  flatten = None
    return fc
    
# To see more debug info, please use `graph_module.print_readable()`
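
As a quick sanity check (a sketch, assuming the search_and_replace above), the rewritten GraphModule can be compared numerically against the original model on a dummy input; symbolic_trace does not modify the original module, so both can be evaluated side by side:

model = resnet18(weights=ResNet18_Weights.IMAGENET1K_V1).eval()
transformed = search_and_replace(model).eval()

x = torch.randn(1, 3, 224, 224)
with torch.no_grad():
    # CustomAdd (stack + sum) should match a plain addition up to float rounding,
    # so this is expected to print True
    print(torch.allclose(model(x), transformed(x), atol=1e-6))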