Subclass of `fx.GraphModule` generating its own code

Hi all,

Is it possible to neatly encapsulate a `torch.fx.GraphModule` that generates its own graph (based, say, on `__init__` parameters)?

The obvious approach appears not to work:

In [8]: class Mod(fx.GraphModule):
   ...:     def __init__(self):
   ...:         super().__init__(self, fx.Graph())
   ...:         def thingy(x):
   ...:             return x * 2.
   ...:         self.graph = fx.symbolic_trace(thingy)
   ...:         self.recompile()
   ...: 

In [9]: m = Mod()

In [10]: m(torch.ones(4))
---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
<ipython-input-10-56e7b4f5b37c> in <module>
----> 1 m(torch.ones(4))

~/anaconda3/envs/sciml/lib/python3.8/site-packages/torch/fx/graph_module.py in wrapped_call(self, *args, **kwargs)
    306             try:
    307                 sys.excepthook = print_full_traceback
--> 308                 return cls_call(self, *args, **kwargs)
    309             finally:
    310                 sys.excepthook = old_excepthook

~/anaconda3/envs/sciml/lib/python3.8/site-packages/torch/fx/graph_module.py in wrapped_call(self, *args, **kwargs)
    306             try:
    307                 sys.excepthook = print_full_traceback
--> 308                 return cls_call(self, *args, **kwargs)
    309             finally:
    310                 sys.excepthook = old_excepthook

~/anaconda3/envs/sciml/lib/python3.8/site-packages/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs)
    887             result = self._slow_forward(*input, **kwargs)
    888         else:
--> 889             result = self.forward(*input, **kwargs)
    890         for hook in itertools.chain(
    891                 _global_forward_hooks.values(),

TypeError: forward() takes 1 positional argument but 2 were given

Adding an explicit `self` breaks `recompile()`:

In [11]: class Mod(fx.GraphModule):
    ...:     def __init__(self):
    ...:         super().__init__(self, fx.Graph())
    ...:         def thingy(self, x):
    ...:             return x * 2.
    ...:         self.graph = fx.symbolic_trace(thingy)
    ...:         self.recompile()
    ...: 

In [12]: m = Mod()
Traceback (most recent call last):
...
    self.recompile()

  File "<eval_with_key_9>", line 2
    def forward(self, self, x):
    ^
SyntaxError: duplicate argument 'self' in function definition

Is this sort of architecture — module that generates its own compute graph — supported?

Thanks!

Well, never mind… this was a small mistake, but I'm leaving the post up in case it happens to anyone else:

In [17]: class Mod(fx.GraphModule):
    ...:     def __init__(self):
    ...:         super().__init__(self, fx.Graph())
    ...:         def thingy(x):
    ...:             return x * 2.
    ...:         self.graph = fx.symbolic_trace(thingy).graph
    ...:         self.recompile()
    ...: 

In [18]: m = Mod()

In [19]: m(torch.ones(4))
Out[19]: tensor([2., 2., 2., 2.])

It works when I set `self.graph` to `symbolic_trace(...).graph` instead of the `GraphModule` that `symbolic_trace` returns.
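For completeness, tracing isn't the only option here: a `GraphModule` subclass can also build its graph directly with the `fx.Graph` node API. A minimal sketch (the `Doubler` name is mine, not from the thread):

```python
import operator

import torch
from torch import fx


class Doubler(fx.GraphModule):
    """Sketch: a GraphModule that constructs its own compute graph
    node-by-node, with no symbolic tracing involved."""

    def __init__(self):
        graph = fx.Graph()
        x = graph.placeholder("x")  # becomes the `x` arg of forward()
        out = graph.call_function(operator.mul, (x, 2.0))
        graph.output(out)
        # The root can be any nn.Module; this graph has no get_attr
        # nodes, so nothing needs to be copied from it.
        super().__init__(torch.nn.Module(), graph)


m = Doubler()
print(m(torch.ones(4)))  # tensor([2., 2., 2., 2.])
```

This sidesteps the `GraphModule`-vs-`Graph` confusion entirely, since you never hold a traced `GraphModule` in the first place.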

This raises a different question: why doesn’t it throw an error when I try to assign a GraphModule to self.graph?

> This raises a different question: why doesn’t it throw an error when I try to assign a GraphModule to self.graph?

This is a good question. We rely on Python type annotations to ensure type correctness, and in our development we typically use mypy to check the type-correctness of the code. However, in this case it seems that `graph`’s setter doesn’t have an annotation for its parameter. I will go ahead and add one, as well as an assert with an `isinstance` check. Thanks for bringing this up!
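For anyone hitting this before such a fix lands, the guard described above is easy to add on the caller's side. A sketch (`set_graph_checked` is my name, not a PyTorch API):

```python
import torch
from torch import fx


def set_graph_checked(gm: fx.GraphModule, g: fx.Graph) -> None:
    """Hypothetical helper: reject anything that is not an fx.Graph,
    catching the GraphModule-vs-Graph mix-up early and loudly."""
    if not isinstance(g, fx.Graph):
        raise TypeError(
            f"expected fx.Graph, got {type(g).__name__}; "
            "did you mean symbolic_trace(...).graph?"
        )
    gm.graph = g
    gm.recompile()


def thingy(x):
    return x * 2.0


traced = fx.symbolic_trace(thingy)
set_graph_checked(traced, traced.graph)  # fine
# set_graph_checked(traced, traced)      # raises TypeError
```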

Makes sense, thanks!

@James_Reed I do have a follow-up question on this that is a little less trivial. Since it’s the same subject as the original thread title, I figured I’d post it here:

How can I generate a graph for `forward()` of a `GraphModule` that has access to `self`? For example:

In [1]: import torch

In [2]: from torch import fx

In [3]: class M(fx.GraphModule):
   ...:     def __init__(self):
   ...:         super().__init__(self, fx.Graph())
   ...:         graph = fx.symbolic_trace(M.fake_forward)
   ...:         self.graph = graph
   ...:         self.recompile()
   ...:         self.register_buffer("w1", torch.ones(3))
   ...:     def fake_forward(self, x):
   ...:         return self.w1 * x
   ...: 

In [4]: m = M()

Gives:

  File "<eval_with_key_1>", line 2
    def forward(self, self, x):
    ^
SyntaxError: duplicate argument 'self' in function definition

Is this possible?

Hi @Linux-cpp-lisp,

This happens because, if the argument to `symbolic_trace` is not an `nn.Module`, we treat it as a free function, regardless of whether it’s a class method or not. I’ll file an issue and see if we can figure out a way to distinguish this case vs. the truly-free-function case: [FX] Cannot symbolically trace class method · Issue #54785 · pytorch/pytorch · GitHub
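Until that case is supported, one workaround for the example above (a sketch; `_FakeForward` is an illustrative name) is to move the logic and state into a small helper `nn.Module`, trace the helper *instance* so the trace sees a real `self` with its buffers, and then hand both the traced root and its graph to `GraphModule.__init__`:

```python
import torch
from torch import fx


class _FakeForward(torch.nn.Module):
    """Holds the logic and state we want traced; tracing an nn.Module
    instance gives the trace access to `self` and its buffers."""

    def __init__(self):
        super().__init__()
        self.register_buffer("w1", torch.ones(3))

    def forward(self, x):
        return self.w1 * x


class M(fx.GraphModule):
    def __init__(self):
        traced = fx.symbolic_trace(_FakeForward())
        # Passing `traced` as the root copies the `w1` buffer
        # (referenced by a get_attr node) into this module.
        super().__init__(traced, traced.graph)


m = M()
print(m(torch.ones(3)))  # tensor([1., 1., 1.])
```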
