“self.graph.owning_module not set for purity check” error when trying to remove nodes from torch.fx.graph
I am trying to remove nodes from a `torch.fx` graph, but it does not work properly.
Here is the example. It uses a small model with two conv layers and returns the sum of their outputs. I turned the `conv2` and `add` nodes into rogue nodes and then connected the output node directly to `conv1`.
When I try to clean up the rogue nodes with `graph.eliminate_dead_code()`, the purity check fails and the program stops. This is the log:
```
torch.Size([64, 8, 28, 28])
opcode         name    target                   args            kwargs
-------------  ------  -----------------------  --------------  --------
placeholder    x       x                        ()              {}
call_module    conv1   conv1                    (x,)            {}
call_module    conv2   conv2                    (x,)            {}
call_function  add     <built-in function add>  (conv1, conv2)  {}
output         output  output                   (add,)          {}

==== Predecessors and successors of each node: ====
node: x
Predecessors:
Successors:
    conv1 <class 'torch.fx.node.Node'>
    conv2 <class 'torch.fx.node.Node'>
----------------------------------------------
node: conv1
Predecessors:
    x <class 'torch.fx.node.Node'>
Successors:
    add <class 'torch.fx.node.Node'>
----------------------------------------------
node: conv2
Predecessors:
    x <class 'torch.fx.node.Node'>
Successors:
    add <class 'torch.fx.node.Node'>
----------------------------------------------
node: add
Predecessors:
    conv1 <class 'torch.fx.node.Node'>
    conv2 <class 'torch.fx.node.Node'>
Successors:
    output <class 'torch.fx.node.Node'>
----------------------------------------------
node: output
Predecessors:
    add <class 'torch.fx.node.Node'>
Successors:
----------------------------------------------

opcode         name    target                   args      kwargs
-------------  ------  -----------------------  --------  --------
placeholder    x       x                        ()        {}
call_module    conv1   conv1                    (x,)      {}
call_module    conv2   conv2                    ()        {}
call_function  add     <built-in function add>  ()        {}
output         output  output                   (conv1,)  {}

==== Predecessors and successors of each node: ====
node: x
Predecessors:
Successors:
    conv1 <class 'torch.fx.node.Node'>
----------------------------------------------
node: conv1
Predecessors:
    x <class 'torch.fx.node.Node'>
Successors:
    output <class 'torch.fx.node.Node'>
----------------------------------------------
node: conv2
Predecessors:
Successors:
----------------------------------------------
node: add
Predecessors:
Successors:
----------------------------------------------
node: output
Predecessors:
    conv1 <class 'torch.fx.node.Node'>
Successors:
----------------------------------------------

Traceback (most recent call last):
  File "github_question.py", line 77, in <module>
    graph.eliminate_dead_code()
  File "/home/shengcn/anaconda3/envs/torch110/lib/python3.7/site-packages/torch/fx/graph.py", line 1170, in eliminate_dead_code
    if not node.is_impure() and len(node.users) == 0:
  File "/home/shengcn/anaconda3/envs/torch110/lib/python3.7/site-packages/torch/fx/node.py", line 511, in is_impure
    ), "self.graph.owning_module not set for purity check"
AssertionError: self.graph.owning_module not set for purity check
```
Jordan Fix (jfix) replied:
You need to set `graph.owning_module` to the `GraphModule` that owns the graph. This is because for `call_module` nodes, we need to fetch the original module instance to check whether it is marked as pure (i.e., whether it has no side effects) before removing it.
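To make the reasoning concrete: a `call_module` node stores only the submodule's qualified name as a string target (for example `"conv2"`), so the graph by itself cannot say which `nn.Module` instance that name refers to. The sketch below assumes a recent PyTorch. `resolve_call_module` is a hypothetical helper, not part of `torch.fx`; it only mirrors the module lookup that the assertion protects, while the real purity check then inspects the resolved module in a version-dependent way.

```python
import torch.nn as nn
import torch.fx as fx


class TwoConv(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 8, kernel_size=5, padding=2, bias=False)
        self.conv2 = nn.Conv2d(1, 8, kernel_size=3, padding=1, bias=False)

    def forward(self, x):
        return self.conv1(x) + self.conv2(x)


def resolve_call_module(node: fx.Node) -> nn.Module:
    """Hypothetical helper: map a call_module node to its nn.Module instance."""
    assert node.op == "call_module"
    owner = node.graph.owning_module
    if owner is None:
        # Same situation as the assertion in the question: nowhere to look up the name.
        raise ValueError(f"cannot resolve {node.target!r}: graph has no owning_module")
    # node.target is a qualified attribute name such as "conv2".
    return owner.get_submodule(node.target)


model = TwoConv()

# A bare Graph produced by Tracer.trace() is not owned by any GraphModule.
bare_graph = fx.Tracer().trace(model)
bare_conv1 = next(n for n in bare_graph.nodes if n.op == "call_module")
try:
    resolve_call_module(bare_conv1)
except ValueError as err:
    print("bare graph:", err)

# symbolic_trace() returns a GraphModule that owns its graph, so the lookup works.
gm = fx.symbolic_trace(model)
for node in gm.graph.nodes:
    if node.op == "call_module":
        print(node.target, "->", type(resolve_call_module(node)).__name__)
# prints "conv1 -> Conv2d", then "conv2 -> Conv2d"
```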
BTW, this is usually done by default when calling `torch.fx.symbolic_trace()`. Is there a reason you don't want to use that helper instead?
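If the graph is built by hand instead of through `symbolic_trace()`, wrapping it in a `GraphModule` is one way to give it an owner. A minimal sketch, assuming a recent PyTorch; the `DeadConv2` model is made up for illustration and simply produces a `conv2` call whose result is never used, which is the kind of rogue node dead-code elimination should remove.

```python
import torch
import torch.nn as nn
import torch.fx as fx


class DeadConv2(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 8, kernel_size=5, padding=2, bias=False)
        self.conv2 = nn.Conv2d(1, 8, kernel_size=3, padding=1, bias=False)

    def forward(self, x):
        self.conv2(x)  # traced as a call_module node that nothing uses
        return self.conv1(x)


model = DeadConv2()

# A graph traced by hand has no owning_module yet, so calling
# eliminate_dead_code() on it at this point may hit the assertion from the question.
graph = fx.Tracer().trace(model)

# Wrapping the graph in a GraphModule registers the module as its owner.
gm = fx.GraphModule(model, graph)
assert gm.graph.owning_module is gm

# Now call_module targets can be resolved, and the unused conv2 node is dropped.
gm.graph.eliminate_dead_code()
gm.recompile()
print(gm.code)
```

Note that dead-code elimination only removes the node; the `conv2` submodule object stays attached to `gm`.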
(Marked as the solution.)

shengchunnan replied:
I updated the code a little and now it works.
Many thanks for your kind suggestion.
```python
import torch
import torch.nn as nn
import torch.fx as fx


class Network(nn.Module):
    def __init__(self, num_classes=10):
        super(Network, self).__init__()
        self.conv1 = nn.Conv2d(1, 8, kernel_size=5, stride=1, padding=2, bias=False)
        self.conv2 = nn.Conv2d(1, 8, kernel_size=3, stride=1, padding=1, bias=False)

    def forward(self, x):
        out1 = self.conv1(x)
        out2 = self.conv2(x)
        return out1 + out2


def delete_conv2(graph):
    """Detach conv2 and add, and feed conv1 straight into the output node."""
    preserved_conv_node = None
    output_node = None
    for node in graph.nodes:
        if 'call_module' == node.op:
            if 'conv1' in node.name:
                preserved_conv_node = node
            # Detach conv2 from its input (it becomes a rogue node)
            elif 'conv2' in node.name:
                node.args = ()
        # Detach the add operation from conv1 and conv2
        elif 'call_function' == node.op and 'add' == node.name:
            node.args = ()
        elif 'output' == node.op:
            output_node = node
    if preserved_conv_node and output_node:
        output_node.args = (preserved_conv_node,)


def print_pre_and_suc(graph):
    """Print the inputs (args) and users of every node in the graph."""
    print('\n==== Predecessors and successors of each node: ====')
    for node in graph.nodes:
        print('node: {}'.format(node.name))
        print('Predecessors:')
        for item in node.args:
            print('\t', item, type(item))
        print('Successors:')
        for key, val in node.users.items():
            print('\t', key, type(key))
        print('----------------------------------------------')
    print()


if __name__ == "__main__":
    model = Network()
    y = model(torch.randn(64, 1, 28, 28))
    print(y.size())

    # symbolic_trace() returns a GraphModule that owns the traced graph
    fx_model = fx.symbolic_trace(model)
    graph = fx_model.graph
    graph.print_tabular()
    print_pre_and_suc(graph)

    delete_conv2(graph)
    graph.print_tabular()
    print_pre_and_suc(graph)

    # Remove the now-unused conv2 and add nodes, then regenerate forward()
    graph.eliminate_dead_code()
    fx_model.recompile()
    print(fx_model.code)

    y = fx_model(torch.randn(64, 1, 28, 28))
    print(y.size())
```
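An alternative to clearing `args` by hand is to rewire users with `Node.replace_all_uses_with()` and delete nodes explicitly with `Graph.erase_node()`. The sketch below performs the same edit as `delete_conv2`, assuming a recent PyTorch and the node names `symbolic_trace` produced in the log (`conv1`, `conv2`, `add`); `remove_conv2_branch` is a hypothetical name, and `Network` is repeated so the block runs on its own.

```python
import torch
import torch.nn as nn
import torch.fx as fx


class Network(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 8, kernel_size=5, stride=1, padding=2, bias=False)
        self.conv2 = nn.Conv2d(1, 8, kernel_size=3, stride=1, padding=1, bias=False)

    def forward(self, x):
        out1 = self.conv1(x)
        out2 = self.conv2(x)
        return out1 + out2


def remove_conv2_branch(gm: fx.GraphModule) -> None:
    """Hypothetical helper: make the graph return conv1 and drop conv2/add."""
    nodes = {node.name: node for node in gm.graph.nodes}
    conv1, conv2, add = nodes["conv1"], nodes["conv2"], nodes["add"]

    # Every user of `add` (here only the output node) now reads conv1 instead.
    add.replace_all_uses_with(conv1)

    # Erase users before their inputs: erase_node rejects nodes that still have users.
    gm.graph.erase_node(add)
    gm.graph.erase_node(conv2)

    # Optionally drop the conv2 submodule, which no node references any more.
    gm.delete_submodule("conv2")

    gm.graph.lint()
    gm.recompile()


gm = fx.symbolic_trace(Network())
remove_conv2_branch(gm)
print(gm.code)
```

Erasing the nodes directly avoids leaving rogue nodes in the graph, so no separate `eliminate_dead_code()` pass is needed for this particular edit.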
- You can see from the log that `conv2` and `add` were successfully turned into rogue nodes (no predecessors and no successors), but `graph.eliminate_dead_code()` fails with the assertion on `self.graph.owning_module`.
- Why do we need this assertion? And how can I make sure that `owning_module` is not `None`?