torch.fx: UnboundLocalError: local variable 'mul' referenced before assignment

I want to use torch.fx to add a new function layer to a network.

To do that, I borrowed the following approach from the torch.fx page of the PyTorch master documentation:

import torch
import torch.fx as fx
import torch.nn.functional as F

# Note that this decomposition rule can be read as regular Python
def relu_decomposition(x):
    return (x > 0) * x

decomposition_rules = {}
decomposition_rules[F.relu] = relu_decomposition

def decompose(model: torch.nn.Module,
              tracer_class : type = fx.Tracer) -> torch.nn.Module:
    """
    Decompose `model` into smaller constituent operations.
    Currently, this only supports decomposing ReLU into its
    mathematical definition: (x > 0) * x
    """
    graph : fx.Graph = tracer_class().trace(model)
    new_graph = fx.Graph()
    env = {}
    tracer = torch.fx.proxy.GraphAppendingTracer(graph)
    for node in graph.nodes:
        if node.op == 'call_function' and node.target in decomposition_rules:
            # By wrapping the arguments with proxies,
            # we can dispatch to the appropriate
            # decomposition rule and implicitly add it
            # to the Graph by symbolically tracing it.
            proxy_args = [
                fx.Proxy(env[x.name], tracer) if isinstance(x, fx.Node) else x for x in node.args]
            output_proxy = decomposition_rules[node.target](*proxy_args)

            # Operations on `Proxy` always yield new `Proxy`s, and the
            # return value of our decomposition rule is no exception.
            # We need to extract the underlying `Node` from the `Proxy`
            # to use it in subsequent iterations of this transform.
            new_node = output_proxy.node
            env[node.name] = new_node
        else:
            # Default case: we don't have a decomposition rule for this
            # node, so just copy the node over into the new graph.
            new_node = new_graph.node_copy(node, lambda x: env[x.name])
            env[node.name] = new_node
    return fx.GraphModule(model, new_graph)
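To see what a decomposition rule actually does, here is a minimal hand-built sketch (assumption: a one-input graph built directly with fx.Graph rather than traced from a model). Calling relu_decomposition on a Proxy records a `gt` node and a `mul` node into whichever graph the GraphAppendingTracer wraps. The second recorded node is named `mul`, the same name that shows up in the error in the title.

```python
import torch
import torch.fx as fx
import torch.nn.functional as F

def relu_decomposition(x):
    return (x > 0) * x

# Hand-built graph: x -> relu_decomposition(x) -> output
graph = fx.Graph()
x_node = graph.placeholder('x')

# The tracer decides WHERE the recorded operations go: here, into `graph`
tracer = fx.proxy.GraphAppendingTracer(graph)
out_proxy = relu_decomposition(fx.Proxy(x_node, tracer))
graph.output(out_proxy.node)

# A bare nn.Module works as the root because the graph uses no submodules
gm = fx.GraphModule(torch.nn.Module(), graph)
print(gm.code)  # forward contains a `gt = x > 0` line and a `mul = gt * x` line

node_names = [n.name for n in graph.nodes]
t = torch.randn(4, 5)
```

The recorded nodes land in the graph passed to GraphAppendingTracer, so that argument must be the graph you are building.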

I want to add a new_mul_layer after the add node, for example:

import torch
import torch.nn as nn
import torch.fx as fx
from torch.fx import Tracer as tracer_class

class M(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(1, 3, 1)
    
    def forward(self, x):
        out = self.conv(x)
        out = torch.add(out, out)
        return out
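The transform decides what to rewrite by matching `node.op` and `node.target`, so it can help to look at the traced graph of `M` first. A short sketch that traces the same model with fx.Tracer and lists its nodes:

```python
import torch
import torch.nn as nn
import torch.fx as fx

class M(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(1, 3, 1)

    def forward(self, x):
        out = self.conv(x)
        out = torch.add(out, out)
        return out

graph = fx.Tracer().trace(M())
for node in graph.nodes:
    # op is one of: placeholder / call_module / call_function / output
    print(node.op, node.name, node.target, node.args)

ops = [node.op for node in graph.nodes]
# The node the question's `node.target == torch.add` check matches
add_nodes = [n for n in graph.nodes if n.op == 'call_function' and n.target == torch.add]
```

Both arguments of the `add` node are the same `conv` node, which matters for how many inputs the replacement rule receives.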

def new_mul_layer(x):
    return x * 0

def decompose(model):
    decomposition_rules = {}
    decomposition_rules[torch.add] = new_mul_layer
    graph = tracer_class().trace(model)
    new_graph = fx.Graph()
    env = {}
    tracer = fx.proxy.GraphAppendingTracer(graph)
    for node in graph.nodes:
        if node.op == 'call_function' and node.target == torch.add:
            # wrap the Node input of add in a Proxy so new_mul_layer is traced
            for x in node.args:
                if isinstance(x, fx.Node):
                    proxy_args = fx.Proxy(env[x.name], tracer)
                else:
                    proxy_args = x
            output_proxy = decomposition_rules[node.target](proxy_args)
            new_node = output_proxy.node
            env[node.name] = new_node
        else:
            # default case: copy the node over into the new graph
            new_node = new_graph.node_copy(node, lambda x: env[x.name])
            env[node.name] = new_node
    return fx.GraphModule(model, new_graph)
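A likely culprit, offered as a hedged reading rather than a confirmed diagnosis: the GraphAppendingTracer is built on `graph` (the traced graph) instead of `new_graph` (the graph being built). The `mul` node created by new_mul_layer is then recorded in the old graph, while new_graph's output still refers to a node named `mul`, so the generated forward can read `mul` before anything in it assigns `mul`. A minimal sketch of the transform with the tracer bound to new_graph, assuming the intent is to replace `torch.add` with new_mul_layer and pass it only add's first argument:

```python
import torch
import torch.nn as nn
import torch.fx as fx

class M(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(1, 3, 1)

    def forward(self, x):
        out = self.conv(x)
        out = torch.add(out, out)
        return out

def new_mul_layer(x):
    return x * 0

def decompose(model: nn.Module) -> fx.GraphModule:
    decomposition_rules = {torch.add: new_mul_layer}
    graph = fx.Tracer().trace(model)
    new_graph = fx.Graph()
    env = {}
    # Key change: record traced operations into new_graph, not the old graph
    tracer = fx.proxy.GraphAppendingTracer(new_graph)
    for node in graph.nodes:
        if node.op == 'call_function' and node.target in decomposition_rules:
            proxy_args = [
                fx.Proxy(env[a.name], tracer) if isinstance(a, fx.Node) else a
                for a in node.args]
            # new_mul_layer takes one input, so only add's first argument is passed
            output_proxy = decomposition_rules[node.target](proxy_args[0])
            env[node.name] = output_proxy.node
        else:
            # Default case: copy the node, remapping its inputs into new_graph
            env[node.name] = new_graph.node_copy(node, lambda a: env[a.name])
    return fx.GraphModule(model, new_graph)

data = torch.randn(1, 1, 3, 3)
new_model = decompose(M())
print(new_model.code)
result = new_model(data)
```

With the tracer on new_graph, the `mul` node is appended right after the copied `conv` node and before the output node is copied, so the generated forward assigns `mul` before returning it.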

if __name__ == "__main__":
    data = torch.randn(1, 1, 3, 3)
    model = M()
    result = model(data)
    model = decompose(model)
    result = model(data)

However, running it raises the following error:

UnboundLocalError: local variable 'mul' referenced before assignment
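UnboundLocalError is Python's error for reading a local name before it has been assigned. An assignment anywhere in a function makes that name local to the whole function, so a global of the same name is never consulted. A plain-Python illustration; the `forward` function is hypothetical and only mimics the shape of code a GraphModule generates, where each node becomes one assignment in graph order:

```python
mul = "a global named mul"  # a global with the same name does not help

def forward(x):
    y = mul          # reads `mul` before the assignment below has run
    mul = x * 0      # this assignment makes `mul` local to the whole function
    return y

try:
    forward(2)
    raised = False
except UnboundLocalError as exc:
    raised = True
    print(exc)
```

If a graph's output refers to `mul` before (or without) the node that assigns it, the generated forward would fail the same way. On Python 3.11 and later the message reads "cannot access local variable 'mul' where it is not associated with a value" instead.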

This 'mul' comes from:

def new_mul_layer(x):
    return x * 0

I understand that this error means a local variable is read before it has been assigned (a local vs. global scoping issue).

But I'm following the official tutorial. Why does it happen here?
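If the goal is literally to add new_mul_layer after the add node (keeping `torch.add` and feeding its result into the new layer) rather than replacing add, a different technique from the tutorial's copy-into-a-new-graph approach is to edit the traced graph in place with `Graph.inserting_after`. A minimal sketch, assuming that wiring is what is intended; add's original consumers (here only the output node) are rerouted to the new node with `Node.replace_input_with`:

```python
import torch
import torch.nn as nn
import torch.fx as fx

class M(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(1, 3, 1)

    def forward(self, x):
        out = self.conv(x)
        out = torch.add(out, out)
        return out

def new_mul_layer(x):
    return x * 0

gm = fx.symbolic_trace(M())
graph = gm.graph
for node in list(graph.nodes):  # iterate over a snapshot while editing
    if node.op == 'call_function' and node.target == torch.add:
        users = list(node.users)  # remember add's original consumers
        with graph.inserting_after(node):
            # Trace new_mul_layer into the SAME graph, right after add
            tracer = fx.proxy.GraphAppendingTracer(graph)
            mul_node = new_mul_layer(fx.Proxy(node, tracer)).node
        for user in users:
            user.replace_input_with(node, mul_node)

graph.lint()
gm.recompile()  # regenerate forward() from the edited graph
print(gm.code)
result = gm(torch.randn(1, 1, 3, 3))
```

Because only one graph is involved, there is no way for the new node to end up in a different graph from the nodes that use it.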