I am working with PyTorch Lightning and want to apply some torch.fx transformations to the model's code.
The issue I face is that when you trace a model, you lose the original model instance and get a
GraphModule instance instead. This has two effects:
- The resulting class loses all methods of the PyTorch Lightning model (e.g. no
- Any state in that class that is not part of the graph is lost.
The above means that the traced model cannot be used in Lightning.
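To make the two effects above concrete, here is a minimal sketch. It assumes a plain nn.Module as a stand-in for a Lightning model; TinyModel, its learning_rate attribute and the check variables are hypothetical names used only for illustration.

```python
import torch
from torch import nn
from torch.fx import GraphModule, symbolic_trace


class TinyModel(nn.Module):
    """Hypothetical stand-in for a Lightning model: forward plus extra method and state."""

    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(4, 2)
        # Plain Python state that forward() never touches, so it is not in the graph.
        self.learning_rate = 1e-3

    def forward(self, x):
        return torch.relu(self.linear(x))

    def configure_optimizers(self):
        # A method that a training framework would call on the model.
        return torch.optim.SGD(self.parameters(), lr=self.learning_rate)


model = TinyModel()
traced = symbolic_trace(model)

# The traced object is a GraphModule, not an instance of the original class.
is_graph_module = isinstance(traced, GraphModule)
is_original_class = isinstance(traced, TinyModel)

# Methods and non-graph state of the original class are not carried over.
has_method = hasattr(traced, "configure_optimizers")
has_state = hasattr(traced, "learning_rate")

# The forward computation itself is preserved.
x = torch.randn(3, 4)
same_output = torch.equal(traced(x), model(x))
```

In this sketch only the parts reachable from forward (here, the linear submodule) survive tracing; everything else on the original class has to be reattached by hand.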
It seemed to me that the way to solve this is by subclassing GraphModule, which is
mentioned in a few places, but I couldn't make it work.
- All the copy/serialization methods do not work properly. The
__new__ implementation interacts badly with inheritance.
#63883 comes up very fast as you get
a class that is not
Of course it is possible to rewrite all the copy methods, but this would duplicate most of
the GraphModule code and, moreover, be very brittle to implementation changes inside
- Did anyone manage to properly subclass GraphModule?
- Is there any other way to trace a PyTorch Lightning model, or any
other model with non-graph state for that matter?