From the FX documentation, I learned that there are several node types:
- placeholder
- get_attr
- output
- call_method
- call_function
- call_module
I have difficulty telling the last three concepts apart. Could someone explain them with more examples?
Here is an example:
```python
import torch
import torch.fx as fx
import torch.nn as nn

def myadd(x, y):
    return x + y

class Adder(nn.Module):
    def forward(self, x, y):
        return x + y

class M(nn.Module):
    def __init__(self):
        super(M, self).__init__()
        self.adder = Adder()

    def forward(self, x, y):
        return x + y, myadd(x, y), torch.add(x, y), x.add(y), self.adder(x, y)

fx.symbolic_trace(M()).graph.print_tabular()
```

Output:

```
opcode         name    target                                               args                                  kwargs
-------------  ------  ---------------------------------------------------  ------------------------------------  --------
placeholder    x       x                                                    ()                                    {}
placeholder    y       y                                                    ()                                    {}
call_function  add     <built-in function add>                              (x, y)                                {}
call_function  add_1   <built-in function add>                              (x, y)                                {}
call_function  add_2   <built-in method add of type object at 0x10310d390>  (x, y)                                {}
call_method    add_3   add                                                  (x, y)                                {}
call_function  add_4   <built-in function add>                              (x, y)                                {}
output         output  output                                               ((add, add_1, add_2, add_3, add_4),)  {}
```
Why is the `Adder` module not traced as `call_module`?

What is the difference between `Tensor.add` and `myadd`? Is it because `Tensor.add` is bound to a Tensor object and is therefore not a free function?
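`call_function` is, e.g., your `myadd`, but the FX symbolic tracer traces through such user-defined functions unless you mark them with FX `wrap`. Tracing through `myadd` records only the `x + y` inside it, which is why `add_1` in your output is a `call_function` whose target is `<built-in function add>` (i.e. `operator.add`) rather than `myadd`.

`call_module` is, e.g., your `Adder` module, but the FX symbolic tracer traces through such modules unless you tell it not to via `is_leaf_module`. See the default logic for `is_leaf_module` in `torch.fx.Tracer`: roughly, only modules that come from `torch.nn` itself are treated as leaves, so `Adder.forward` is inlined and its `x + y` shows up as `add_4`.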
`call_method` is when a method is called on a Tensor, e.g. your `x.add(y)`. Its target is just the method name (`add` in your table), which is looked up on the first argument at run time. The IR here merely reflects how the original Python/PyTorch program expressed the add; you could likely replace every `call_method` `add` with `call_function` `torch.add` with no difference in the module's output.
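The last claim can be checked with a small graph rewrite. This is a minimal sketch using standard `torch.fx.Graph` APIs (`inserting_after`, `call_function`, `replace_all_uses_with`, `erase_node`); the model `M` here is a simplified stand-in for the one in the question, not the original.

```python
import torch
import torch.fx as fx
import torch.nn as nn

class M(nn.Module):
    def forward(self, x, y):
        return x.add(y)  # traced as call_method with target "add"

gm = fx.symbolic_trace(M())
x, y = torch.randn(3), torch.randn(3)
expected = gm(x, y)

# Rewrite every call_method "add" node as a call_function torch.add node.
# Iterate over a copy, because the node list is mutated inside the loop.
for node in list(gm.graph.nodes):
    if node.op == "call_method" and node.target == "add":
        with gm.graph.inserting_after(node):
            # For call_method, args[0] is the tensor the method was called on,
            # so the args (x, y) map directly onto torch.add(x, y).
            new_node = gm.graph.call_function(torch.add, node.args, node.kwargs)
        node.replace_all_uses_with(new_node)
        gm.graph.erase_node(node)

gm.graph.lint()   # sanity-check the mutated graph
gm.recompile()    # regenerate forward() from the new graph

print([(n.op, n.target) for n in gm.graph.nodes if n.op.startswith("call_")])
print(torch.equal(gm(x, y), expected))
```

The same pattern (insert a replacement node, redirect users, erase the old node) works for any one-to-one op substitution, which is why the choice between `call_method` and `call_function` here is largely a matter of how the source was written.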