FX tracing PyTorch Lightning models

I am working with PyTorch Lightning and want to apply some FX transformations to the model's code.

The issue I face is that when you trace a model, you lose the original model instance and get a
GraphModule instance instead. This has two effects:

  1. The resulting class loses all the methods of the PyTorch Lightning model (e.g. no training_step).
  2. Any state in the original instance that is not part of the graph is lost.

The above means that the traced model cannot be used in Lightning.
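
To make the two effects concrete, here is a minimal sketch. It uses a plain torch.nn.Module as a stand-in for the LightningModule so that Lightning itself is not needed; TinyModel, its learning_rate attribute and the training_step body are made up for illustration and are not from my real code.

```python
import torch
import torch.fx


class TinyModel(torch.nn.Module):
    """Hypothetical stand-in for a LightningModule (plain nn.Module here)."""

    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(4, 2)
        self.learning_rate = 1e-3  # non-graph state: never used in forward()

    def forward(self, x):
        return torch.relu(self.linear(x))

    def training_step(self, batch, batch_idx):
        # Lightning-style hook; not reachable from forward(), so not traced.
        x, y = batch
        return torch.nn.functional.mse_loss(self(x), y)


model = TinyModel()
traced = torch.fx.symbolic_trace(model)

# The traced object is a GraphModule, not an instance of the original class.
is_graph_module = isinstance(traced, torch.fx.GraphModule)
is_original_class = isinstance(traced, TinyModel)

# Effect 1: methods other than forward() are gone.
has_training_step = hasattr(traced, "training_step")

# Effect 2: attributes that the graph does not reference are gone.
has_learning_rate = hasattr(traced, "learning_rate")

# The forward computation itself is preserved.
x = torch.randn(3, 4)
same_output = torch.allclose(traced(x), model(x))

print(type(traced).__name__, is_original_class, has_training_step, has_learning_rate)
```

The submodule linear survives because the graph references it; everything else on the original instance is dropped.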
It seemed to me that the way to solve this is to subclass GraphModule, which is
mentioned in a few places, but I couldn’t make it work:

  1. The copy and serialization methods do not work properly.
  2. The __call__ wrapping and __new__ implementation interact badly with inheritance.
    Issue #63883 shows up very quickly, since you end up with
    a class that is not GraphModuleImpl.

Of course, it is possible to rewrite all the copy methods, __new__ and __call__,
but this would duplicate most of the GraphModule code and, moreover, be very brittle to
implementation changes inside GraphModule.

  • Has anyone managed to properly subclass GraphModule?
  • Is there any other way to trace a PyTorch Lightning model, or any
    other model with non-graph state for that matter?