Hi there!
I am currently trying to apply JIT (TorchScript) optimizations to the source code of a Tree-LSTM model. The Tree class is a crucial part of the model, so I need to turn it into a TorchScript custom class so that the model's core methods can use it. That’s when I ran into a problem:
```python
import torch
from typing import List

@torch.jit.script
class Tree(object):
    def __init__(self):
        self.parent = None
        self.num_children = 0
        self.children = torch.jit.annotate(List[Tree], [])
    # further definitions omitted
```
When I try to run the code, I get the following error:
```
RuntimeError:
Unknown type name Tree:
def __init__(self):
    self.parent = None
    self.num_children = 0
    self.children = torch.jit.annotate(List[Tree], [])
                                            ~~~~ <--- HERE
```
The issue is that Tree is a recursive structure: the children of a tree node are themselves a list of tree nodes. So for the children attribute of the Tree class, I need to declare an empty list whose element type is Tree. But since the Tree class definition is still incomplete at that point, the TorchScript compiler does not recognize Tree as a type.
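For reference, here is a minimal sketch of one alternative representation that avoids the self-reference altogether: store the tree as index lists (a parent index per node and a list of child indices per node) instead of nested node objects. FlatTree, add_node and num_children are hypothetical names, not part of the original model. The sketch is plain Python without the @torch.jit.script decorator. Its attributes only use List[int] and List[List[int]], which are types TorchScript supports, but whether this exact class compiles under torch.jit.script on a given PyTorch release is an assumption I have not verified (older releases may need torch.jit.annotate instead of the inline annotations).

```python
from typing import List


class FlatTree(object):
    """A tree stored as index lists instead of nested node objects.

    Node i has parent parents[i] (-1 for the root) and children children[i].
    No attribute's type refers to FlatTree itself, so the class never has to
    name its own type before its definition is complete.
    """

    def __init__(self):
        self.parents: List[int] = []
        self.children: List[List[int]] = []

    def add_node(self, parent: int) -> int:
        """Append a node under `parent` (-1 for a root) and return its index."""
        idx = len(self.parents)
        self.parents.append(parent)
        # Annotated so the empty list has a concrete element type.
        no_children: List[int] = []
        self.children.append(no_children)
        if parent >= 0:
            self.children[parent].append(idx)
        return idx

    def num_children(self, idx: int) -> int:
        return len(self.children[idx])


# Build: root -> (a, b), a -> (c)
tree = FlatTree()
root = tree.add_node(-1)
a = tree.add_node(root)
b = tree.add_node(root)
c = tree.add_node(a)
print(tree.children)  # [[1, 2], [3], [], []]
```

The trade-off is that nodes are referred to by integer index rather than by object, so any per-node data (e.g. word ids or hidden states) would live in parallel lists or tensors indexed the same way.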
I am wondering whether there is any way to solve this problem, or whether custom classes like the Tree above are simply not supported in the current version of PyTorch, in which case I should try another approach.
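If the tree is kept as a flat list of parent indices (parents[i] is node i's parent, -1 for a root), the bottom-up pass that a Tree-LSTM needs (children before parents) can be driven by an ordering computed up front. The sketch below is a toy under that assumption: bottom_up_order and bottom_up_sum are hypothetical helpers, and the float sum only stands in for the Child-Sum Tree-LSTM step, where the real model would sum the children's hidden states and apply its gates instead.

```python
from typing import List


def bottom_up_order(parents: List[int]) -> List[int]:
    """Return node indices ordered so that every child precedes its parent."""
    n = len(parents)
    children: List[List[int]] = [[] for _ in range(n)]
    stack: List[int] = []
    for i in range(n):
        if parents[i] < 0:
            stack.append(i)  # roots start the traversal
        else:
            children[parents[i]].append(i)
    order: List[int] = []
    while len(stack) > 0:
        node = stack.pop()
        order.append(node)  # a node is recorded before any of its descendants
        stack.extend(children[node])
    order.reverse()  # reversed: every descendant now comes before its ancestors
    return order


def bottom_up_sum(parents: List[int], inputs: List[float]) -> List[float]:
    """Toy stand-in for Child-Sum aggregation: each node's result is its own
    input plus the results of its children (i.e. the sum over its subtree)."""
    results = list(inputs)
    for node in bottom_up_order(parents):
        p = parents[node]
        if p >= 0:
            results[p] += results[node]  # a child's result is final before it is added
    return results


parents = [-1, 0, 0, 1, 1]
print(bottom_up_order(parents))  # [3, 4, 1, 2, 0]
print(bottom_up_sum(parents, [1.0, 2.0, 3.0, 4.0, 5.0]))  # [15.0, 11.0, 3.0, 4.0, 5.0]
```

Because the traversal is an explicit loop over indices rather than recursion over node objects, nothing in it needs a recursive type.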
It would be great if someone could give me a hand. Thanks a lot!