How to write a Tree-LSTM in PyTorch?

Hi all,

I’m wondering how I should write a Tree-LSTM in PyTorch. Are there any particular functionalities I should look into? Are there any similar examples I should look at? What is the rough outline of the implementation?




You can have a look at this implementation:


Or this one, if you want a partially batched approach like the one in the SPINN paper:

I have already implemented a PyTorch version here, which I converted from Torch code.

In your implementation of the BinaryTreeLSTM,

  1. Can you explain the leaf (base) case? self.ox just passes the input through a linear layer, which is understandable since leaves have no hidden state, but why are there no weight parameters for the input, update, or forget gates, and why is the cell state just passed through a linear layer?

  2. Also, for non-leaf nodes you skip passing the input through a linear layer entirely, for all the gating units. Is there a reason for that? In ChildSum you have weight parameters for x_j, so why not in the N-ary LSTM?

self.ix = nn.Linear(self.in_dim, self.mem_dim)
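For comparison, here is a minimal sketch of the binary (N = 2) case of the N-ary Tree-LSTM equations from Tai et al. (2015), where every gate receives the node input x through a weight matrix W, and each child gets its own forget gate (sharing W for the forget gates, with separate U weights per child). The class name `BinaryTreeLSTMCell` and the packing of all gates into two `nn.Linear` layers are hypothetical, illustrative choices, not taken from the implementation discussed above. A leaf is treated as a node whose child states are zeros, and an internal node without a word can be given a zero input vector.

```python
import torch
import torch.nn as nn


class BinaryTreeLSTMCell(nn.Module):
    """Hypothetical binary Tree-LSTM cell following the N-ary equations (N = 2)."""

    def __init__(self, in_dim: int, mem_dim: int):
        super().__init__()
        self.mem_dim = mem_dim
        # Input projections for i, f, o, u; the f projection is shared by both forget gates.
        self.W = nn.Linear(in_dim, 4 * mem_dim)
        # Child-state projections for i, f_L, f_R, o, u, computed from [h_L; h_R].
        self.U = nn.Linear(2 * mem_dim, 5 * mem_dim, bias=False)

    def forward(self, x, left=None, right=None):
        # x: (batch, in_dim); left/right: (h, c) tuples, each of shape (batch, mem_dim).
        # A missing child (e.g. at a leaf) contributes zero hidden and cell states.
        zeros = x.new_zeros((x.size(0), self.mem_dim))
        h_l, c_l = left if left is not None else (zeros, zeros)
        h_r, c_r = right if right is not None else (zeros, zeros)

        wi, wf, wo, wu = self.W(x).chunk(4, dim=-1)
        ui, uf_l, uf_r, uo, uu = self.U(torch.cat([h_l, h_r], dim=-1)).chunk(5, dim=-1)

        i = torch.sigmoid(wi + ui)
        f_l = torch.sigmoid(wf + uf_l)  # forget gate for the left child's memory
        f_r = torch.sigmoid(wf + uf_r)  # forget gate for the right child's memory
        o = torch.sigmoid(wo + uo)
        u = torch.tanh(wu + uu)

        c = i * u + f_l * c_l + f_r * c_r
        h = o * torch.tanh(c)
        return h, c


cell = BinaryTreeLSTMCell(in_dim=4, mem_dim=3)
leaf_a = cell(torch.randn(1, 4))
leaf_b = cell(torch.randn(1, 4))
root_h, root_c = cell(torch.zeros(1, 4), leaf_a, leaf_b)  # internal node with no word input
```

With this formulation the leaf case needs no special parameters: it is the general cell with zero child states, so all gates still depend on x.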

An implementation with easy-first parsing, with explanations:

Tree-LSTMs are a conceptually straightforward extension of RNN-LSTMs. Here’s what we need:

  1. A parser that takes in a sentence and outputs its parse tree structure.
    • Stanford CoreNLP, or your own parser
  2. Given the parse tree structure, which implicitly encodes how the word units should be progressively combined, convert it into a series of instructions that explicitly describes how the nodes should be combined.

  3. Write an RNN that takes in a series of instructions on how to combine a list of inputs.
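Steps 2 and 3 can be sketched as follows, assuming the parse tree has already been binarized into nested 2-tuples of tokens (a hypothetical input format). The hypothetical `to_transitions` function flattens the tree in post-order into SPINN-style SHIFT/REDUCE instructions, and the hypothetical `ShiftReduceTreeLSTM` module executes them with a stack: SHIFT embeds a word as a leaf, REDUCE pops two states and merges them with a binary Tree-LSTM composition. Taking no word input at internal nodes is a simplifying assumption here, not a prescribed design.

```python
import torch
import torch.nn as nn


def to_transitions(tree):
    """Flatten a binary tree (a token string or a 2-tuple of subtrees) into
    post-order SHIFT/REDUCE instructions."""
    if isinstance(tree, str):
        return [("SHIFT", tree)]
    left, right = tree
    return to_transitions(left) + to_transitions(right) + [("REDUCE", None)]


class ShiftReduceTreeLSTM(nn.Module):
    """Hypothetical stack machine that runs a Tree-LSTM over transitions."""

    def __init__(self, vocab, emb_dim, mem_dim):
        super().__init__()
        self.index = {w: i for i, w in enumerate(vocab)}
        self.embed = nn.Embedding(len(vocab), emb_dim)
        # Leaf (assumption): c = W_c x, h = sigmoid(W_o x) * tanh(c).
        self.leaf = nn.Linear(emb_dim, 2 * mem_dim)
        # Reduce: gates i, f_L, f_R, o, u computed from [h_L; h_R] only.
        self.reduce = nn.Linear(2 * mem_dim, 5 * mem_dim)

    def forward(self, transitions):
        stack = []
        for op, token in transitions:
            if op == "SHIFT":
                x = self.embed(torch.tensor([self.index[token]]))  # (1, emb_dim)
                c, o = self.leaf(x).chunk(2, dim=-1)
                stack.append((torch.sigmoid(o) * torch.tanh(c), c))
            else:
                # The right child was pushed last, so it is popped first.
                h_r, c_r = stack.pop()
                h_l, c_l = stack.pop()
                i, f_l, f_r, o, u = self.reduce(torch.cat([h_l, h_r], dim=-1)).chunk(5, dim=-1)
                c = (torch.sigmoid(i) * torch.tanh(u)
                     + torch.sigmoid(f_l) * c_l
                     + torch.sigmoid(f_r) * c_r)
                h = torch.sigmoid(o) * torch.tanh(c)
                stack.append((h, c))
        assert len(stack) == 1, "malformed transition sequence"
        return stack[0]  # (h, c) of the root


tree = (("the", "cat"), ("sat", ("on", "mats")))
model = ShiftReduceTreeLSTM(["the", "cat", "sat", "on", "mats"], emb_dim=8, mem_dim=6)
root_h, root_c = model(to_transitions(tree))
```

Running one instruction at a time keeps the logic simple but is slow. Batching several sentences means stepping their stacks in lockstep, which is roughly what the partially batched SPINN approach mentioned earlier in the thread does.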

Perhaps using a well-tested graph library is the way to go: Tutorial: Tree-LSTM in DGL — DGL 0.4.3post2 documentation
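If you would rather stay in plain PyTorch, the key batching idea (processing, in one call, every node whose children are already finished) can be sketched by grouping nodes by height. This is similar in spirit to topological-order propagation in a graph library, but it is not DGL's API. All names below (`BinaryCell`, `node_heights`, `levelwise_tree_lstm`) are hypothetical, and the tree is assumed to be given as a list where `children[n]` is `[]` for a leaf or `[left, right]` for an internal node.

```python
import torch
import torch.nn as nn


class BinaryCell(nn.Module):
    """Hypothetical compact binary Tree-LSTM cell (gates from x and [h_L; h_R])."""

    def __init__(self, in_dim, mem_dim):
        super().__init__()
        self.W = nn.Linear(in_dim, 5 * mem_dim)
        self.U = nn.Linear(2 * mem_dim, 5 * mem_dim, bias=False)

    def forward(self, x, h_l, c_l, h_r, c_r):
        i, f_l, f_r, o, u = (self.W(x) + self.U(torch.cat([h_l, h_r], -1))).chunk(5, -1)
        c = torch.sigmoid(i) * torch.tanh(u) + torch.sigmoid(f_l) * c_l + torch.sigmoid(f_r) * c_r
        return torch.sigmoid(o) * torch.tanh(c), c


def node_heights(children):
    """Leaves have height 0; a parent is 1 + the max height of its children."""
    heights = {}

    def visit(n):
        if n not in heights:
            heights[n] = 0 if not children[n] else 1 + max(visit(k) for k in children[n])
        return heights[n]

    return [visit(n) for n in range(len(children))]


def levelwise_tree_lstm(cell, x, children, mem_dim):
    """Run the cell bottom-up, one batched call per height level.

    x: (num_nodes, in_dim), e.g. word vectors for leaves and zeros elsewhere.
    """
    num_nodes = x.size(0)
    h = x.new_zeros((num_nodes, mem_dim))
    c = x.new_zeros((num_nodes, mem_dim))
    heights = node_heights(children)
    for level in range(max(heights) + 1):
        nodes = [n for n in range(num_nodes) if heights[n] == level]
        idx = torch.tensor(nodes)
        if level == 0:
            zeros = x.new_zeros((len(nodes), mem_dim))
            h_new, c_new = cell(x[idx], zeros, zeros, zeros, zeros)
        else:
            # Children have smaller heights, so their states are already filled in.
            left = torch.tensor([children[n][0] for n in nodes])
            right = torch.tensor([children[n][1] for n in nodes])
            h_new, c_new = cell(x[idx], h[left], c[left], h[right], c[right])
        # Out-of-place index_copy avoids in-place writes on tensors autograd still needs.
        h = h.index_copy(0, idx, h_new)
        c = c.index_copy(0, idx, c_new)
    return h, c
```

Several trees can be batched by concatenating their node lists and offsetting the child indices, since nothing in the loop assumes a single root.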