Today, I want to add a new function layer in network by torch.fx
So, I borrowed from this approach
this code from: torch.fx — PyTorch master documentation
# Note that this decomposition rule can be read as regular Python
def relu_decomposition(x):
return (x > 0) * x
decomposition_rules = {}
decomposition_rules[F.relu] = relu_decomposition
def decompose(model: torch.nn.Module,
tracer_class : type = fx.Tracer) -> torch.nn.Module:
"""
Decompose `model` into smaller constituent operations.
Currently,this only supports decomposing ReLU into its
mathematical definition: (x > 0) * x
"""
graph : fx.Graph = tracer_class().trace(model)
new_graph = fx.Graph()
env = {}
tracer = torch.fx.proxy.GraphAppendingTracer(graph)
for node in graph.nodes:
if node.op == 'call_function' and node.target in decomposition_rules:
# By wrapping the arguments with proxies,
# we can dispatch to the appropriate
# decomposition rule and implicitly add it
# to the Graph by symbolically tracing it.
proxy_args = [
fx.Proxy(env[x.name], tracer) if isinstance(x, fx.Node) else x for x in node.args]
output_proxy = decomposition_rules[node.target](*proxy_args)
# Operations on `Proxy` always yield new `Proxy`s, and the
# return value of our decomposition rule is no exception.
# We need to extract the underlying `Node` from the `Proxy`
# to use it in subsequent iterations of this transform.
new_node = output_proxy.node
env[node.name] = new_node
else:
# Default case: we don't have a decomposition rule for this
# node, so just copy the node over into the new graph.
new_node = new_graph.node_copy(node, lambda x: env[x.name])
env[node.name] = new_node
return fx.GraphModule(model, new_graph)
I want to add a new_mul_layer after the add node
such as:
import torch
import torch.nn as nn
import torch.fx as fx
from torch.fx import Tracer as tracer_class
class M(nn.Module):
def __init__(self):
super().__init__()
self.conv = nn.Conv2d(1, 3, 1)
def forward(self, x):
out = self.conv(x)
out = torch.add(out, out)
return out
def new_mul_layer(x):
return x * 0
def decompose(model):
decomposition_rules = {}
decomposition_rules[torch.add] = new_mul_layer
graph = tracer_class().trace(model)
new_graph = fx.Graph()
env = {}
tracer = fx.proxy.GraphAppendingTracer(graph)
for node in graph.nodes:
if node.op == 'call_function' and node.target == torch.add:
for x in node.args:
if isinstance(x, fx.Node):
proxy_args = fx.Proxy(env[x.name], tracer)
else:
proxy_args = x
output_proxy = decompostition_rules[node.target](proxy_args)
env[node.name] = new_node
return fx.GraphModule(model, new_graph)
if __name__ == "__main__":
data = torch.randn(1, 1, 3, 3)
model = M()
result = model(data)
model = decompose(model)
result = model(data)
However, the following error was reported:
UnboundLocalError : local variable ‘mul’ referenced before assignment.
this ‘mul’ from
def new_mul_layer(x):
return x * 0
I know this error is a problem with global and local variable conversions
But I’m following the official tutorial. Why is this a problem?