Is there support for recursive custom classes like Tree in Tree-LSTM?

Hi there!
I am currently trying to make JIT optimization work on the source code of a Tree-LSTM model. The Tree class is a crucial part of the model, so I need to turn it into a TorchScript custom class so it can be used by the model's core methods. That's where I ran into a problem:

import torch
from typing import List

@torch.jit.script
class Tree(object):
    def __init__(self):
        self.parent = None
        self.num_children = 0
        self.children = torch.jit.annotate(List[Tree], [])

    # further definitions omitted

When I try to run the code, here is the error:

RuntimeError: 
Unknown type name Tree:
def __init__(self):
    self.parent = None
    self.num_children = 0
    self.children = torch.jit.annotate(List[Tree], [])
                                            ~~~~ <--- HERE

So the problem is: Tree is inherently a recursive structure; the children of a tree node are themselves a list of tree nodes. Therefore, for the children attribute of the Tree class, I need to declare an empty list of type List[Tree]. But since the definition of the Tree class is still in progress at that point, the compiler cannot recognize the Tree type yet.

I am wondering whether there is any way to solve this problem, or whether custom classes like the Tree above are simply not supported in the current version of PyTorch and I should try another approach.

It would be great if someone could give me a hand. Thanks a lot!

I believe this is fixed in master (your class compiles for me). Could you try your code on our nightly release and see if it works?

The code runs successfully after I switched to the nightly release, cool! Thank you!

But the bad news is that new errors arise when the compiler processes other methods. For example:

import torch
from typing import Optional

@torch.jit.script
class Tree(object):
    def __init__(self):
        self.parent = torch.jit.annotate(Optional[Tree], None)

    def add_child(self, child):
        # type: (Tree) -> None
        child.parent = torch.jit.annotate(Optional[Tree], self)

When execution reaches the last line, an error arises:

RuntimeError: 
expected an expression of type Optional[__torch__.Tree] but found __torch__.Tree:
at /Users/fere/repos/treelstm.pytorch/treelstm/tree.py:25:43
    def add_child(self, child):
        # type: (Tree) -> None
        child.parent = torch.jit.annotate(Optional[Tree], self)
                                          ~~~~~~~~~~~~~~~~~~~~ <--- HERE

The compiler regards self as Tree instead of Optional[Tree], which makes sense, but I thought torch.jit.annotate would be able to do some casting work. The reason I think so is that in the code above, the compiler treats torch.jit.annotate(Optional[Tree], None) as a value of type Optional[Tree] rather than as None. I've gone through the TorchScript documentation and done some searching, but the only way I've found to specify the type of a value, other than for function parameters and return values, is torch.jit.annotate. I also haven't found any way to cast a value of type T to Optional[T]. So I am curious: is there anything I've missed?

Another problem is below:

import torch

@torch.jit.script
class Tree(object):
    def __init__(self):
        self.num_children = torch.jit.annotate(int, 0)

    def add_child(self):
        self.num_children += 1

Error:

RuntimeError: 
left-hand side of augmented assignment to module parameters/buffers can only be tensor types:
at /Users/fere/repos/treelstm.pytorch/treelstm/tree.py:15:9
    def add_child(self):
        self.num_children += 1
        ~~~~~~~~ <--- HERE

Does the error message mean that only Tensor-typed attributes can be updated in a custom TorchScript class?

Thankful and grateful for your help!

First of all—thank you for the feedback! Support for classes is in its early days and reports from intrepid users are super valuable.

To summarize, it seems that you are running into two problems:

  1. torch.jit.annotate() is not doing implicit type promotion from T to Optional[T] for class types.
    This is a bug on our side, and we'll look into fixing it. Could you file a GitHub issue so we can track it there?

  2. Augmented assignment doesn’t work on non-tensors. This is a known issue, and we have a fix coming this week for it. In the meantime, just re-assigning is an easy workaround:

def add_child(self):
    self.num_children = self.num_children + 1
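To illustrate, here is a minimal sketch of the re-assignment workaround. The @torch.jit.script decorator is omitted so the snippet runs as plain Python; under scripting, the same pattern applies:

```python
from typing import List


class Tree(object):
    """Sketch of the workaround: plain re-assignment instead of
    augmented assignment on a non-tensor attribute."""

    def __init__(self):
        self.num_children = 0
        self.children: List["Tree"] = []

    def add_child(self, child: "Tree") -> None:
        # Re-assign instead of `self.num_children += 1`, which
        # TorchScript rejected for non-tensor attributes at the time.
        self.num_children = self.num_children + 1
        self.children.append(child)


root = Tree()
root.add_child(Tree())
root.add_child(Tree())
print(root.num_children)  # 2
```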

Actually, there's no need to file an issue: https://github.com/pytorch/pytorch/pull/21593 fixes problem 1.

Update:
We do not actually support recursive class definitions at the moment; I didn't remember that when I first replied. https://github.com/pytorch/pytorch/pull/21842 improves the error message in this case.

We will likely support it soon (before the next release) but for now it will not work.

Thanks a lot for the update! Guess now I need to wait or look for a temporary workaround. :grinning:
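In case it helps anyone else waiting on recursive class support: one possible interim workaround (my own sketch, not from the PyTorch team) is to sidestep the recursive class entirely and encode the tree as flat index arrays, which only needs List[int]-style fields that TorchScript already handles. Shown here as plain Python with the @torch.jit.script decorator omitted so it runs standalone:

```python
from typing import List


class FlatTree(object):
    """Tree encoded as flat arrays of node indices instead of a
    recursive class: parent[i] is the index of node i's parent
    (-1 for the root), and children[i] lists node i's child indices."""

    def __init__(self):
        self.parent: List[int] = []
        self.children: List[List[int]] = []

    def add_node(self, parent_idx: int) -> int:
        # Append a new node and return its index.
        idx = len(self.parent)
        self.parent.append(parent_idx)
        self.children.append([])
        if parent_idx >= 0:
            self.children[parent_idx].append(idx)
        return idx


t = FlatTree()
root = t.add_node(-1)
left = t.add_node(root)
right = t.add_node(root)
print(t.children[root])  # [1, 2]
```

The trade-off is that tree traversal becomes index arithmetic rather than attribute access, but no type refers to itself, so nothing depends on recursive class support.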