"self.graph.owning_module not set for purity check" error when trying to remove a node from torch.fx.graph


I am trying to remove nodes from a torch.fx graph, but it does not work as expected.

My example uses a small model with two conv layers whose outputs are summed. I turned the conv2 and add nodes into rogue nodes and then connected the output node directly to conv1.
When I try to clean up the rogue nodes with graph.eliminate_dead_code(), the program stops with the purity check assertion error.
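A minimal sketch of the rewiring described above, plus a check for the resulting "rogue" nodes. It assumes a stand-in model (TwoConv, hypothetical) shaped like the one in the question; find_rogue_nodes is a hypothetical helper built only on the public Node.op, Node.all_input_nodes and Node.users attributes.

```python
import torch.nn as nn
import torch.fx as fx


class TwoConv(nn.Module):
    # Hypothetical stand-in for the two-conv model in the question.
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 8, kernel_size=5, padding=2, bias=False)
        self.conv2 = nn.Conv2d(1, 8, kernel_size=3, padding=1, bias=False)

    def forward(self, x):
        return self.conv1(x) + self.conv2(x)


def find_rogue_nodes(graph: fx.Graph):
    # "Rogue" here means a compute node with no input nodes and no users.
    return [
        n.name
        for n in graph.nodes
        if n.op not in ("placeholder", "output")
        and not n.all_input_nodes
        and not n.users
    ]


traced = fx.symbolic_trace(TwoConv())
nodes = {n.name: n for n in traced.graph.nodes}

# Same idea as in the question: detach conv2 and add, return conv1 directly.
nodes["conv2"].args = ()
nodes["add"].args = ()
nodes["output"].args = (nodes["conv1"],)

print(find_rogue_nodes(traced.graph))
```

The nodes are only detached here; they stay in the graph until something such as eliminate_dead_code() removes them.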

  • As the log below shows, conv2 and add are successfully turned into rogue nodes (no predecessors, no successors), but graph.eliminate_dead_code() fails on the self.graph.owning_module assertion.
  • Why is this assertion needed? And how can I make sure that owning_module is not None?
torch.Size([64, 8, 28, 28])
opcode         name    target                   args            kwargs
-------------  ------  -----------------------  --------------  --------
placeholder    x       x                        ()              {}
call_module    conv1   conv1                    (x,)            {}
call_module    conv2   conv2                    (x,)            {}
call_function  add     <built-in function add>  (conv1, conv2)  {}
output         output  output                   (add,)          {}

==== Predecessors and successors of each node: ====
node: x
Predecessors:
Successors:
	 conv1 <class 'torch.fx.node.Node'>
	 conv2 <class 'torch.fx.node.Node'>
----------------------------------------------
node: conv1
Predecessors:
	 x <class 'torch.fx.node.Node'>
Successors:
	 add <class 'torch.fx.node.Node'>
----------------------------------------------
node: conv2
Predecessors:
	 x <class 'torch.fx.node.Node'>
Successors:
	 add <class 'torch.fx.node.Node'>
----------------------------------------------
node: add
Predecessors:
	 conv1 <class 'torch.fx.node.Node'>
	 conv2 <class 'torch.fx.node.Node'>
Successors:
	 output <class 'torch.fx.node.Node'>
----------------------------------------------
node: output
Predecessors:
	 add <class 'torch.fx.node.Node'>
Successors:
----------------------------------------------

opcode         name    target                   args      kwargs
-------------  ------  -----------------------  --------  --------
placeholder    x       x                        ()        {}
call_module    conv1   conv1                    (x,)      {}
call_module    conv2   conv2                    ()        {}
call_function  add     <built-in function add>  ()        {}
output         output  output                   (conv1,)  {}

==== Predecessors and successors of each node: ====
node: x
Predecessors:
Successors:
	 conv1 <class 'torch.fx.node.Node'>
----------------------------------------------
node: conv1
Predecessors:
	 x <class 'torch.fx.node.Node'>
Successors:
	 output <class 'torch.fx.node.Node'>
----------------------------------------------
node: conv2
Predecessors:
Successors:
----------------------------------------------
node: add
Predecessors:
Successors:
----------------------------------------------
node: output
Predecessors:
	 conv1 <class 'torch.fx.node.Node'>
Successors:
----------------------------------------------

Traceback (most recent call last):
  File "github_question.py", line 77, in <module>
    graph.eliminate_dead_code()
  File "/home/shengcn/anaconda3/envs/torch110/lib/python3.7/site-packages/torch/fx/graph.py", line 1170, in eliminate_dead_code
    if not node.is_impure() and len(node.users) == 0:
  File "/home/shengcn/anaconda3/envs/torch110/lib/python3.7/site-packages/torch/fx/node.py", line 511, in is_impure
    ), "self.graph.owning_module not set for purity check"
AssertionError: self.graph.owning_module not set for purity check
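The original question code is not shown, so this is only a guess at how the traceback above can arise: a sketch assuming the graph was produced by calling fx.Tracer().trace() directly rather than fx.symbolic_trace(). A bare Graph returned by Tracer.trace() is not attached to any GraphModule, so its owning_module is None and the purity check on call_module nodes fails. TwoConv and run_dce_on_bare_graph are hypothetical names.

```python
import torch.nn as nn
import torch.fx as fx


class TwoConv(nn.Module):
    # Hypothetical stand-in for the two-conv model in the question.
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 8, kernel_size=5, padding=2, bias=False)
        self.conv2 = nn.Conv2d(1, 8, kernel_size=3, padding=1, bias=False)

    def forward(self, x):
        return self.conv1(x) + self.conv2(x)


def run_dce_on_bare_graph():
    # Tracer.trace() returns a plain Graph that no GraphModule owns yet.
    graph = fx.Tracer().trace(TwoConv())
    print("owning_module:", graph.owning_module)
    try:
        # The purity check for call_module nodes needs graph.owning_module.
        graph.eliminate_dead_code()
    except AssertionError as err:
        return str(err)
    return None


print(run_dce_on_bare_graph())
```

Note that no graph edits are needed to trigger it: any call_module node is enough, because eliminate_dead_code() asks every node whether it is impure.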

Jordan Fix (marked as the solution):
You need to set graph.owning_module to the GraphModule that owns the graph. This is because, for call_module nodes, we need to fetch the original module instance to check whether it is marked as pure (i.e., has no side effects) before removing the node.
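To illustrate why the module instance matters, here is a sketch that relies on what appears to be FX's internal convention: Node.is_impure() looks up the submodule of a call_module node via owning_module.get_submodule(target) and reads a private _is_impure attribute, defaulting to False. That attribute is an assumption about private internals and may change between PyTorch versions; TwoConv is hypothetical. Marking conv2 as impure should make eliminate_dead_code() keep it even though its result is unused, while the unused add node is removed.

```python
import torch
import torch.nn as nn
import torch.fx as fx


class TwoConv(nn.Module):
    # Hypothetical stand-in for the two-conv model in the question.
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 8, kernel_size=5, padding=2, bias=False)
        self.conv2 = nn.Conv2d(1, 8, kernel_size=3, padding=1, bias=False)

    def forward(self, x):
        return self.conv1(x) + self.conv2(x)


model = TwoConv()
# Private flag read by Node.is_impure() (assumption: internal detail, may change).
model.conv2._is_impure = True

traced = fx.symbolic_trace(model)
nodes = {n.name: n for n in traced.graph.nodes}

# Return conv1 directly, leaving add and conv2 without users.
nodes["output"].args = (nodes["conv1"],)

traced.graph.eliminate_dead_code()
traced.recompile()

remaining = [n.name for n in traced.graph.nodes]
print(remaining)  # add should be gone; conv2 is kept because it is flagged impure

out = traced(torch.randn(2, 1, 28, 28))
print(out.shape)
```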

BTW, this is usually done by default when you call torch.fx.symbolic_trace(). Is there a reason you don't want to use that helper instead?
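A short sketch comparing the two ways of obtaining a graph, using the public fx.symbolic_trace, fx.Tracer and fx.GraphModule APIs (TwoConv is hypothetical): the graph inside the GraphModule returned by symbolic_trace() already has an owner, and a bare Graph from Tracer.trace() gets one once it is wrapped in fx.GraphModule(root, graph).

```python
import torch.nn as nn
import torch.fx as fx


class TwoConv(nn.Module):
    # Hypothetical stand-in for the two-conv model in the question.
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 8, kernel_size=5, padding=2, bias=False)
        self.conv2 = nn.Conv2d(1, 8, kernel_size=3, padding=1, bias=False)

    def forward(self, x):
        return self.conv1(x) + self.conv2(x)


model = TwoConv()

# symbolic_trace() wraps the traced graph in a GraphModule, which becomes its owner.
traced = fx.symbolic_trace(model)
print(traced.graph.owning_module is traced)

# A bare graph from Tracer.trace() has no owner until it is wrapped.
bare_graph = fx.Tracer().trace(model)
print(bare_graph.owning_module is None)

gm = fx.GraphModule(model, bare_graph)
print(bare_graph.owning_module is gm)

# With an owner in place, the purity check can resolve call_module targets.
gm.graph.eliminate_dead_code()
```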

shengchunnan:
I updated the code a little, and now it works.
Many thanks for your kind suggestion.

import torch
import torch.nn as nn
import torch.fx as fx

class Network(nn.Module):
    def __init__(self, num_classes=10):
        super(Network, self).__init__()

        self.conv1 = nn.Conv2d(1, 8, kernel_size=5, stride=1, padding=2, bias=False)
        self.conv2 = nn.Conv2d(1, 8, kernel_size=3, stride=1, padding=1, bias=False)

    def forward(self, x):

        out1 = self.conv1(x)
        out2 = self.conv2(x)

        return out1 + out2


def delete_conv2(graph):

    preserved_conv_node = None
    output_node = None

    for node in graph.nodes:
        if 'call_module' == node.op:
            if 'conv1' in node.name:
                preserved_conv_node = node
            # Detach conv2 from its input; eliminate_dead_code() removes it afterwards
            elif 'conv2' in node.name:
                node.args = ()
        # Detach the add node from conv1/conv2; eliminate_dead_code() removes it afterwards
        elif 'call_function' == node.op and 'add' == node.name:
            node.args = ()
        elif 'output' == node.op:
            output_node = node

    if preserved_conv_node and output_node:
        output_node.args = (preserved_conv_node,)


def print_pre_and_suc(graph):

    print('\n==== Predecessors and successors of each node: ====')
    for node in graph.nodes:
        print('node: {}'.format(node.name))

        print('Predecessors:')
        for item in node.args:
            print('\t', item, type(item))

        print('Successors:')
        for user in node.users:
            print('\t', user, type(user))
        print('----------------------------------------------')
    print()
    

if __name__ == "__main__":

    model = Network()

    y = model(torch.randn(64, 1, 28, 28))
    print(y.size())

    fx_model = fx.symbolic_trace(model)
    graph = fx_model.graph
    
    graph.print_tabular()
    print_pre_and_suc(graph)

    delete_conv2(graph)

    graph.print_tabular()
    print_pre_and_suc(graph)

    graph.eliminate_dead_code()
    fx_model.recompile()
    print(fx_model.code)

    y = fx_model(torch.randn(64, 1, 28, 28))
    print(y.size())
torch.Size([64, 8, 28, 28])
opcode         name    target                   args            kwargs
-------------  ------  -----------------------  --------------  --------
placeholder    x       x                        ()              {}
call_module    conv1   conv1                    (x,)            {}
call_module    conv2   conv2                    (x,)            {}
call_function  add     <built-in function add>  (conv1, conv2)  {}
output         output  output                   (add,)          {}

==== Predecessors and successors of each node: ====
node: x
Predecessors:
Successors:
	 conv1 <class 'torch.fx.node.Node'>
	 conv2 <class 'torch.fx.node.Node'>
----------------------------------------------
node: conv1
Predecessors:
	 x <class 'torch.fx.node.Node'>
Successors:
	 add <class 'torch.fx.node.Node'>
----------------------------------------------
node: conv2
Predecessors:
	 x <class 'torch.fx.node.Node'>
Successors:
	 add <class 'torch.fx.node.Node'>
----------------------------------------------
node: add
Predecessors:
	 conv1 <class 'torch.fx.node.Node'>
	 conv2 <class 'torch.fx.node.Node'>
Successors:
	 output <class 'torch.fx.node.Node'>
----------------------------------------------
node: output
Predecessors:
	 add <class 'torch.fx.node.Node'>
Successors:
----------------------------------------------

opcode         name    target                   args      kwargs
-------------  ------  -----------------------  --------  --------
placeholder    x       x                        ()        {}
call_module    conv1   conv1                    (x,)      {}
call_module    conv2   conv2                    ()        {}
call_function  add     <built-in function add>  ()        {}
output         output  output                   (conv1,)  {}

==== Predecessors and successors of each node: ====
node: x
Predecessors:
Successors:
	 conv1 <class 'torch.fx.node.Node'>
----------------------------------------------
node: conv1
Predecessors:
	 x <class 'torch.fx.node.Node'>
Successors:
	 output <class 'torch.fx.node.Node'>
----------------------------------------------
node: conv2
Predecessors:
Successors:
----------------------------------------------
node: add
Predecessors:
Successors:
----------------------------------------------
node: output
Predecessors:
	 conv1 <class 'torch.fx.node.Node'>
Successors:
----------------------------------------------




def forward(self, x):
    conv1 = self.conv1(x);  x = None
    return conv1
    
torch.Size([64, 8, 28, 28])
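For reference, the same rewrite can probably be done without leaving rogue nodes behind, by using the public Node.replace_all_uses_with(), Graph.erase_node(), Graph.lint() and GraphModule.delete_all_unused_submodules() calls (the last one is only available in newer releases, so check your PyTorch version). This is an alternative sketch, not the approach from the thread; TwoConv and drop_conv2 are hypothetical names. Because the nodes are erased explicitly, eliminate_dead_code() and its purity check are not involved at all.

```python
import torch
import torch.nn as nn
import torch.fx as fx


class TwoConv(nn.Module):
    # Hypothetical stand-in for the two-conv model in the question.
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 8, kernel_size=5, padding=2, bias=False)
        self.conv2 = nn.Conv2d(1, 8, kernel_size=3, padding=1, bias=False)

    def forward(self, x):
        return self.conv1(x) + self.conv2(x)


def drop_conv2(traced: fx.GraphModule) -> fx.GraphModule:
    graph = traced.graph
    nodes = {n.name: n for n in graph.nodes}
    add_node, conv1_node, conv2_node = nodes["add"], nodes["conv1"], nodes["conv2"]

    # Every consumer of add (here only the output node) now reads conv1 instead.
    add_node.replace_all_uses_with(conv1_node)
    # erase_node() requires the node to have no users: erase add first, then conv2.
    graph.erase_node(add_node)
    graph.erase_node(conv2_node)

    graph.lint()  # sanity-check the edited graph
    traced.recompile()  # regenerate forward() from the graph
    traced.delete_all_unused_submodules()  # drop the now-unused conv2 submodule
    return traced


model = TwoConv()
pruned = drop_conv2(fx.symbolic_trace(model))
print(pruned.code)

x = torch.randn(2, 1, 28, 28)
print(torch.allclose(pruned(x), model.conv1(x)))
```

The pruned module shares its conv1 submodule with the original model, so its output should match model.conv1(x).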