I'm wondering how I should write a Tree-LSTM in PyTorch. Are there any particular features I should look into? Are there any similar examples I should look at? What is the rough outline of the implementation?
You can have a look at this implementation - https://gist.github.com/wolet/1b49c03968b2c83897a4a15c78980b18
Or this one, if you want a partially batched approach like the one in the SPINN paper: https://github.com/jekbradbury/examples/blob/spinn/snli/spinn.py
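The key point of the partially batched approach is that the composition ("reduce") step runs on a whole batch of (left child, right child) pairs at once, even though each pair comes from a different tree. Below is a minimal sketch of a binary Tree-LSTM reduce step. It assumes binary trees, leaf states that are computed separately, and no extra input at internal nodes. The class and attribute names are hypothetical and not taken from spinn.py, and the shift/reduce stack bookkeeping is not shown.

```python
import torch
import torch.nn as nn


class BinaryTreeLSTMReduce(nn.Module):
    """Hypothetical module: composes a batch of (left, right) child states.

    Each row of the batch is an independent node, so many reductions
    (e.g. from different sentences) run in a single call.
    """

    def __init__(self, hidden_size: int):
        super().__init__()
        # Five gate pre-activations: input i, left forget f_l,
        # right forget f_r, output o, candidate u.
        self.left = nn.Linear(hidden_size, 5 * hidden_size)
        self.right = nn.Linear(hidden_size, 5 * hidden_size, bias=False)

    def forward(self, left, right):
        # left / right: tuples (h, c), each of shape (batch, hidden_size)
        (h_l, c_l), (h_r, c_r) = left, right
        gates = self.left(h_l) + self.right(h_r)
        i, f_l, f_r, o, u = gates.chunk(5, dim=1)
        # Separate forget gates let the cell keep or drop each child's memory.
        c = (torch.sigmoid(i) * torch.tanh(u)
             + torch.sigmoid(f_l) * c_l
             + torch.sigmoid(f_r) * c_r)
        h = torch.sigmoid(o) * torch.tanh(c)
        return h, c


if __name__ == "__main__":
    torch.manual_seed(0)
    hidden, batch = 8, 4  # four independent reductions in one call
    reduce = BinaryTreeLSTMReduce(hidden)
    left = (torch.randn(batch, hidden), torch.randn(batch, hidden))
    right = (torch.randn(batch, hidden), torch.randn(batch, hidden))
    h, c = reduce(left, right)
    print(h.shape, c.shape)  # torch.Size([4, 8]) torch.Size([4, 8])
```

Because every row is independent, you can collect all nodes that are ready to be reduced (across sentences) and process them in one call. That is where the speedup over a node-by-node recursion comes from.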
I have already implemented a PyTorch version here: https://github.com/ttpro1995/TreeLSTMSentiment. I converted it from the Torch code at https://github.com/stanfordnlp/treelstm
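As a rough outline, here is a sketch of the Child-Sum Tree-LSTM, one of the variants in the Stanford Torch code. It uses a post-order recursion: compute every child's (c, h), sum the children's hidden states for the input/output/update gates, and use one forget gate per child. The `Tree` class, the layer names (`ioux`, `iouh`, `fx`, `fh`) and the zero "child" used for leaves are illustrative assumptions, not the API of either linked repository.

```python
from dataclasses import dataclass, field
from typing import List, Tuple

import torch
import torch.nn as nn


@dataclass
class Tree:
    """Hypothetical minimal tree node: a token index plus child subtrees."""
    idx: int
    children: List["Tree"] = field(default_factory=list)


class ChildSumTreeLSTM(nn.Module):
    def __init__(self, in_dim: int, mem_dim: int):
        super().__init__()
        self.mem_dim = mem_dim
        # Input (i), output (o) and update (u) gates read x and the summed child h.
        self.ioux = nn.Linear(in_dim, 3 * mem_dim)
        self.iouh = nn.Linear(mem_dim, 3 * mem_dim)
        # Forget gates: one per child, each conditioned on that child's own h_k.
        self.fx = nn.Linear(in_dim, mem_dim)
        self.fh = nn.Linear(mem_dim, mem_dim)

    def node_forward(self, x, child_c, child_h):
        # x: (in_dim,); child_c, child_h: (num_children, mem_dim)
        h_sum = child_h.sum(dim=0)
        i, o, u = (self.ioux(x) + self.iouh(h_sum)).chunk(3, dim=-1)
        i, o, u = torch.sigmoid(i), torch.sigmoid(o), torch.tanh(u)
        f = torch.sigmoid(self.fh(child_h) + self.fx(x))  # broadcasts over children
        c = i * u + (f * child_c).sum(dim=0)
        h = o * torch.tanh(c)
        return c, h

    def forward(self, tree: Tree, inputs: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        # Post-order traversal: children first, then the parent node.
        states = [self.forward(child, inputs) for child in tree.children]
        if states:
            child_c = torch.stack([c for c, _ in states])
            child_h = torch.stack([h for _, h in states])
        else:
            # Leaves get a single all-zero "child" so the same equations apply.
            child_c = inputs.new_zeros(1, self.mem_dim)
            child_h = inputs.new_zeros(1, self.mem_dim)
        return self.node_forward(inputs[tree.idx], child_c, child_h)


if __name__ == "__main__":
    torch.manual_seed(0)
    # Hypothetical 3-token sentence: token 1 is the root, tokens 0 and 2 its children.
    tree = Tree(1, [Tree(0), Tree(2)])
    inputs = torch.randn(3, 10)  # one 10-dim embedding per token
    model = ChildSumTreeLSTM(in_dim=10, mem_dim=6)
    c, h = model(tree, inputs)
    print(h.shape)  # torch.Size([6])
```

The root's `h` can feed a classifier (e.g. `nn.Linear` plus cross-entropy) for sentiment. Autograd handles backpropagation through the recursion. This version processes one node at a time, which is simple but slow; batching across nodes is what the SPINN-style code above addresses.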