I’m wondering: how should I write a Tree-LSTM in PyTorch? Are there any particular functionalities I should look into, or any similar examples I should study? What is the rough outline of the implementation?
Or this one, if you want a partially-batched approach like the SPINN paper: https://github.com/jekbradbury/examples/blob/spinn/snli/spinn.py
In your implementation of the BinaryTreeLSTM, can you explain the leaf / base case? `self.ox` just passes the input through a linear layer, which is understandable since a leaf has no hidden state, but why are there no weight parameters for the input, update, or forget gates, and why is the cell state just a linear projection of the input?

Also, for non-leaf nodes you skip passing the input through a linear layer entirely, for all of the gating units. Is there an explanation for that? In ChildSumTreeLSTM you have weight parameters for x_j; why not in the N-ary LSTM?
self.ix = nn.Linear(self.in_dim, self.mem_dim)
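To make the leaf-vs-internal distinction concrete, here is a minimal sketch of a binary Tree-LSTM cell in the spirit of the N-ary Tree-LSTM of Tai et al. (2015). The class and method names (`BinaryTreeLSTMCell`, `leaf`, `compose`) are my own assumptions, not taken from the implementation being discussed. It illustrates the usual answers to the questions above: a leaf has no children, so it needs no forget gates (there is no child cell state to forget), and in a constituency tree an internal node carries no word of its own, so the input projections (`W*x` terms) are often dropped there entirely.

```python
import torch
import torch.nn as nn

class BinaryTreeLSTMCell(nn.Module):
    """Hypothetical sketch of a binary (N=2) Tree-LSTM cell.

    Leaves consume a word embedding and have no children; internal nodes
    combine two children and (in constituency trees) receive no word input,
    which is why implementations often omit W*x terms at internal nodes.
    """

    def __init__(self, in_dim, mem_dim):
        super().__init__()
        self.mem_dim = mem_dim
        # Leaf: project the word into input/output/update gates.
        # No forget gate: a leaf has no child cell state to forget.
        self.leaf_iou = nn.Linear(in_dim, 3 * mem_dim)
        # Internal node: gates computed from the two children's hidden states.
        self.comp_iou = nn.Linear(2 * mem_dim, 3 * mem_dim)
        # One forget gate per child, each conditioned on both children.
        self.comp_fl = nn.Linear(2 * mem_dim, mem_dim)
        self.comp_fr = nn.Linear(2 * mem_dim, mem_dim)

    def leaf(self, x):
        i, o, u = self.leaf_iou(x).chunk(3, dim=-1)
        c = torch.sigmoid(i) * torch.tanh(u)
        h = torch.sigmoid(o) * torch.tanh(c)
        return h, c

    def compose(self, hl, cl, hr, cr):
        hc = torch.cat([hl, hr], dim=-1)
        i, o, u = self.comp_iou(hc).chunk(3, dim=-1)
        fl = torch.sigmoid(self.comp_fl(hc))
        fr = torch.sigmoid(self.comp_fr(hc))
        c = torch.sigmoid(i) * torch.tanh(u) + fl * cl + fr * cr
        h = torch.sigmoid(o) * torch.tanh(c)
        return h, c
```

In a dependency-tree (Child-Sum) variant every node aligns with a word, so there the `W*x` projections appear at every node; that difference between tree types is likely why the two implementations look inconsistent.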
An implementation with easy-first parsing, with explanations
Tree-LSTMs are a conceptually straightforward extension of sequential RNN-LSTMs. Here’s what we need:
- A parser that takes in a sentence and outputs the parse tree structure (e.g. Stanford CoreNLP, or your own parser).
- Given the parse tree structure, which implicitly encodes how the word units should be progressively combined, convert it into a series of instructions that explicitly describes how the nodes should be combined.
- An RNN that takes in a list of inputs and that series of instructions on how to combine them.
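The last two steps above can be sketched as follows. This is a toy flattened-tree evaluator under my own assumptions (the `TreeCombiner` name, the `(left, right)` instruction format, and the single-linear-layer composition are all illustrative, not from any particular library); a real model would plug a Tree-LSTM cell in as the composition function.

```python
import torch
import torch.nn as nn

class TreeCombiner(nn.Module):
    """Executes a parse tree bottom-up as explicit combine instructions,
    so the tree recursion is flattened into a simple loop."""

    def __init__(self, dim):
        super().__init__()
        # Toy composition function standing in for a Tree-LSTM cell.
        self.reduce_fn = nn.Linear(2 * dim, dim)

    def forward(self, embeddings, instructions):
        """embeddings: list of leaf vectors, one per word.
        instructions: list of (left_index, right_index) pairs; each new
        parent node is appended to the buffer, so later instructions can
        refer to earlier results by index."""
        buf = list(embeddings)
        for left, right in instructions:
            pair = torch.cat([buf[left], buf[right]], dim=-1)
            buf.append(torch.tanh(self.reduce_fn(pair)))
        return buf[-1]  # representation of the root node
```

For example, the tree `((w0 w1) w2)` over three words becomes the instruction list `[(0, 1), (3, 2)]`: combining leaves 0 and 1 produces node 3, which is then combined with leaf 2 to give the root.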
Perhaps using a well-tested graph library is the way to go: Tutorial: Tree-LSTM in DGL — DGL 0.4.3post2 documentation