I am working on a custom backend for TorchDynamo. If a recompilation is triggered for some reason (for example, the input shape changes), the backend should process an fx.Node the same way it did in the first pass (e.g. via some sort of caching).
However, I struggle to match the nodes across recompilations:
- The node name is not unique if there are graph breaks in the model. For example, if the model is broken into two graphs and both contain an `add`, the model will have two nodes named `add`.
- I tried using the number of graph breaks as an additional identifier, but noticed that in some models the recompiled graph has a different number of graph breaks than the originally traced one.
- The `id` of the graph and of the nodes is always different, so these seem to be newly created torch.fx.Node and torch.fx.Graph instances.
Last but not least:
- I used the code snippet stored in the node metadata to generate a hash and used that as an identifier together with the node name. However, I noticed that in the recompiled graph the code can differ slightly. For example, one says "forward", while during the recompilation it says `"<resume in forward>"`.
I probably could work around this, but wanted to raise this here to see if a better solution can be found.
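For reference, here is a minimal sketch of the workaround I have in mind. It assumes the code snippet is available as a plain string (for example from `node.meta.get("stack_trace")`) and that the only difference between passes is the `<resume in ...>` wrapper around the function name. The helper names `normalize_snippet` and `stable_node_key` are hypothetical, not part of torch.

```python
import hashlib
import re

# Matches the "<resume in forward>" form seen after a graph break and
# captures the wrapped function name ("forward").
_RESUME_RE = re.compile(r"<resume in ([^>]+)>")


def normalize_snippet(snippet: str) -> str:
    """Replace "<resume in NAME>" with "NAME" so both passes look alike."""
    return _RESUME_RE.sub(r"\1", snippet)


def stable_node_key(node_name: str, snippet: str) -> str:
    """Hash the node name together with the normalized code snippet.

    Assumption: the snippet text is otherwise identical between the first
    trace and the recompilation.
    """
    payload = node_name + "\n" + normalize_snippet(snippet)
    return hashlib.sha256(payload.encode("utf-8")).hexdigest()


# Illustrative snippets only; real metadata may be formatted differently.
first_pass = 'File "model.py", line 10, in forward\n    y = a + b'
recompiled = 'File "model.py", line 10, in <resume in forward>\n    y = a + b'

key_first = stable_node_key("add", first_pass)
key_recompiled = stable_node_key("add", recompiled)
print(key_first == key_recompiled)
```

Inside the backend, the key would be computed per node from its name and snippet and used to look up cached decisions. This still breaks if two nodes share both a name and a source line, or if anything other than the resume wrapper changes between passes, which is why I am asking whether there is a better approach.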
Thank you