I am currently trying to apply JIT optimizations to the source code of a Tree-LSTM model. The Tree class is a crucial part of the model, so I need to make it a TorchScript custom class that the core methods of the model can use. That's when I ran into a problem:
```python
import torch
from typing import List

@torch.jit.script
class Tree(object):
    def __init__(self):
        self.parent = None
        self.num_children = 0
        # An empty list whose element type is Tree itself
        self.children = torch.jit.annotate(List[Tree], [])

    # further definitions omitted
```
When I run the code, I get this error:
```
RuntimeError:
Unknown type name Tree:
def __init__(self):
    self.parent = None
    self.num_children = 0
    self.children = torch.jit.annotate(List[Tree], [])
                                            ~~~~ <--- HERE
```
So the question is this: Tree is a recursive structure, since the children of a tree node are themselves a list of tree nodes. Therefore, for the children attribute of the Tree class, I need to define an empty list of type Tree. But because the definition of the Tree class is still incomplete at that point, the compiler cannot recognize the Tree type.
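For comparison, one common way to sidestep a self-referencing type is to store all nodes in flat lists and refer to children by integer index, so no attribute ever needs the type of the class being defined. The sketch below is a hypothetical plain-Python illustration of that idea (the class name `FlatTree` and the parent-index input format are my own assumptions, not part of the model); I have not verified whether it compiles under `@torch.jit.script` without further type annotations.

```python
from typing import List


class FlatTree:
    """Hypothetical index-based tree: nodes are integers 0..n-1.

    Children are stored as lists of node indices, so no attribute has
    the type FlatTree and the class never has to name its own type.
    """

    def __init__(self, parents: List[int]):
        # parents[i] is the index of node i's parent, or -1 for the root.
        self.parents: List[int] = parents
        self.children: List[List[int]] = [[] for _ in parents]
        self.root: int = -1
        for i, p in enumerate(parents):
            if p == -1:
                self.root = i
            else:
                self.children[p].append(i)

    def num_children(self, node: int) -> int:
        return len(self.children[node])

    def post_order(self) -> List[int]:
        # Visit every node after all of its descendants, the order in which
        # a Tree-LSTM combines child states into the parent state.
        # Iterative: collect a "node, then children" order, then reverse it.
        out: List[int] = []
        stack: List[int] = [self.root]
        while stack:
            node = stack.pop()
            out.append(node)
            stack.extend(self.children[node])
        out.reverse()
        return out
```

For example, `FlatTree([-1, 0, 0, 1, 1])` describes a root 0 with children 1 and 2, where node 1 has children 3 and 4; its `post_order()` returns `[3, 4, 1, 2, 0]`.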
I am wondering whether there is any way to solve this problem, or whether custom classes like the Tree above are simply not supported in the current version of PyTorch and I should try another approach.
It would be great if someone could give me a hand. Thanks a lot!