I’m looking to track output statistics of operators within networks. Take the following network:
```python
import torch
import torch.nn as nn


class Model(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(3, 3)
        self.fc2 = nn.Linear(3, 3)
        self.fc3 = nn.Linear(3, 3)

    def forward(self, x):
        x = self.fc1(x)
        y = self.fc2(x)
        print("beginning addition")
        z = x + y  # the value whose range I want to track
        return self.fc3(z)
```
I want to be able to track the range of values `z` takes on over time. I could always instantiate the model with a custom layer like `self.add = Add()`, but my goal is for other people to interact with what I’m working on, and I’d rather not force them to declare a new module for each of these operations. I’d much rather have a way to automatically convert instances of `x + y` into an `nn.Module` (like the one below) and then use forward hooks under the hood.
```python
import torch.nn as nn


class Add(nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, x, y):
        return x + y
```
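A minimal sketch of the "forward hooks under the hood" part, assuming the statistic of interest is a running min/max of each addition's output. `register_forward_hook` is the standard `nn.Module` method; `RangeTracker` is a hypothetical helper name, not a PyTorch API.

```python
import torch
import torch.nn as nn


class Add(nn.Module):
    """Same elementwise-addition module as above, repeated so this sketch is self-contained."""

    def forward(self, x, y):
        return x + y


class RangeTracker:
    """Hypothetical helper: a forward hook that records the running min/max of a module's output."""

    def __init__(self):
        self.min = float("inf")
        self.max = float("-inf")

    def __call__(self, module, inputs, output):
        # Forward hooks receive (module, positional inputs, output); returning None keeps the output unchanged.
        self.min = min(self.min, output.min().item())
        self.max = max(self.max, output.max().item())


add = Add()
tracker = RangeTracker()
handle = add.register_forward_hook(tracker)

add(torch.tensor([1.0, -2.0]), torch.tensor([0.5, 0.5]))  # output [1.5, -1.5]
add(torch.tensor([3.0, 0.0]), torch.tensor([1.0, 0.0]))   # output [4.0, 0.0]
print(tracker.min, tracker.max)  # -1.5 4.0

handle.remove()  # detach the hook once tracking is no longer needed
```

Because the hook object keeps its state between calls, the range accumulates over every forward pass until the handle is removed.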
It also seems pretty straightforward to use `torch.fx.symbolic_trace` to generate and modify the network graph (included below), but the traced `GraphModule` gets a regenerated `forward` that replaces my original one. Any `print` statements I’ve added run once during tracing and are then dropped, which I can’t allow either.
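The dropped `print` follows from how tracing works: `symbolic_trace` runs `forward` once with proxy values, records only the tensor operations it sees, and generates new Python source from that graph. The hypothetical `TinyModel` below is a cut-down stand-in for the model above that shows this.

```python
import torch
import torch.fx as fx
import torch.nn as nn


class TinyModel(nn.Module):
    """Hypothetical stand-in for the model above: a print followed by an addition."""

    def forward(self, x):
        print("beginning addition")  # executed while tracing, but not recorded in the graph
        return x + x


traced = fx.symbolic_trace(TinyModel())  # "beginning addition" is printed here, once
print(traced.code)  # the regenerated forward contains the addition but no print call

traced(torch.randn(3))  # prints nothing: the generated forward has no print
```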
I’ve put together a minimal working example that uses a symbolic trace. However, is there a better, more general approach that retains the original `forward` function, just with every instance of `+` replaced by `Add()`? That way, users wouldn’t need to modify their model definitions, and print statements would remain.
```python
import operator

import torch
import torch.fx as fx


def replace_add(model):
    traced = fx.symbolic_trace(model)
    num_adds = 0
    for node in traced.graph.nodes:
        # `x + y` on traced values is recorded as a call_function node targeting operator.add
        if node.op == "call_function" and node.target == operator.add:
            module_instance = Add()
            module_name = f"add_{num_adds}"
            traced.add_module(module_name, module_instance)
            with traced.graph.inserting_after(node):
                add_node = traced.graph.call_module(module_name, args=node.args)
            node.replace_all_uses_with(add_node)
            traced.graph.erase_node(node)
            num_adds += 1
    traced.recompile()
    return traced  # no longer includes print statements


x = torch.randn(3)
original_net = Model()
original_out = original_net(x)
modified_net = replace_add(original_net)
print("we must see 'beginning addition' after this print")
modified_out = modified_net(x)
```
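One possible alternative that leaves `forward` untouched (a different technique from FX graph rewriting, not something from the original post) is to intercept additions at runtime with `torch.overrides.TorchFunctionMode`. This is a minimal sketch, assuming a recent PyTorch 2.x release in which `x + y` on plain tensors reaches the mode as `torch.Tensor.add`; `AddRangeMode` and `ADD_FUNCS` are hypothetical names.

```python
import torch
import torch.nn as nn
from torch.overrides import TorchFunctionMode

# Assumption: the entry points an out-of-place tensor addition can arrive through.
# A tuple is used so membership is checked by equality, without hashing.
ADD_FUNCS = (torch.add, torch.Tensor.add, torch.Tensor.__add__, torch.Tensor.__radd__)


class AddRangeMode(TorchFunctionMode):
    """Hypothetical helper: tracks the running min/max of every tensor addition
    executed while the mode is active, without modifying the model."""

    def __init__(self):
        super().__init__()
        self.calls = 0
        self.min = float("inf")
        self.max = float("-inf")

    def __torch_function__(self, func, types, args=(), kwargs=None):
        kwargs = kwargs or {}
        # The mode is temporarily disabled while this handler runs,
        # so calling func (and .min()/.max() below) does not recurse into it.
        out = func(*args, **kwargs)
        if func in ADD_FUNCS and isinstance(out, torch.Tensor):
            self.calls += 1
            self.min = min(self.min, out.min().item())
            self.max = max(self.max, out.max().item())
        return out


class Model(nn.Module):
    """Same layers and print statement as the model in the question."""

    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(3, 3)
        self.fc2 = nn.Linear(3, 3)
        self.fc3 = nn.Linear(3, 3)

    def forward(self, x):
        x = self.fc1(x)
        y = self.fc2(x)
        print("beginning addition")
        z = x + y
        return self.fc3(z)


torch.manual_seed(0)
net = Model()
x = torch.randn(3)

mode = AddRangeMode()
with mode:
    out = net(x)  # the original forward runs, so "beginning addition" is printed
print(f"{mode.calls} addition(s), output range [{mode.min:.4f}, {mode.max:.4f}]")
```

Because nothing is traced or recompiled, the `print` in `forward` still runs on every call, and reusing the same mode object across calls accumulates the range over time. The trade-off is that the mode sees every addition executed under it, including additions inside library code, while in-place additions such as `x += y` are not matched by `ADD_FUNCS` as written. Attributing statistics to a particular `+` in the source would need extra bookkeeping not shown here.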