What is the standard way to batch and do a forward pass through tree-structured data (ASTs) in pytorch so as to leverage the power of GPUs?

In vision we usually process multiple images at once. I believe this is possible because most images are the same size or we can easily pad them with zeros if they are not (and thus process many at once).
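For comparison, the image case comes down to something like the minimal sketch below. The helper name `pad_and_stack` is made up for illustration. It zero-pads every [C, H, W] image on the right and bottom to the largest height and width, then stacks the results into one batch tensor.

```python
import torch
import torch.nn.functional as F

def pad_and_stack(images):
    """Zero-pad a list of [C, H, W] tensors to a common size and stack them."""
    max_h = max(img.shape[1] for img in images)
    max_w = max(img.shape[2] for img in images)
    padded = [
        # F.pad pads the last dim first: (left, right, top, bottom)
        F.pad(img, (0, max_w - img.shape[2], 0, max_h - img.shape[1]))
        for img in images
    ]
    return torch.stack(padded)  # [B, C, max_h, max_w]

batch = pad_and_stack([torch.ones(3, 4, 5), torch.ones(3, 6, 2)])
print(batch.shape)  # torch.Size([2, 3, 6, 5])
```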

However, I don’t see a simple way to do this for structured data (e.g. programs/code, not NLP) in the form of trees (e.g. while using a TreeLSTM). The processing of the data and the forward pass of the TreeLSTM seem to be tightly coupled. For example, if our TreeLSTM processes the Abstract Syntax Tree (AST) top-down or bottom-up, my guess is that the batching has to be different and custom for each of these. What I am imagining, very vaguely and abstractly, is that we take each node of the AST, generate the tensor for that node and then do one step of the forward pass. To batch this over many ASTs, I’d assume we need to collect ASTs of the same topology (not only the same size: I assume that even two trees of the same size would have to match branch for branch… unless we pad them with zeros?). I’d then imagine the code to be something like this:
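One way to collect ASTs of the same topology is to key each tree by a signature that ignores node labels and keeps only the branching structure. This is a minimal sketch that assumes a hypothetical JSON schema where every node is a dict with an optional "children" list. `shape_signature` and `bucket_by_topology` are illustrative names, not library functions.

```python
from collections import defaultdict

def shape_signature(node):
    """Nested tuple describing only the branching structure of a tree.

    Assumes (hypothetically) that each AST node is a dict with an optional
    "children" list; labels are ignored, so two trees get the same
    signature exactly when their topologies match.
    """
    return tuple(shape_signature(child) for child in node.get("children", []))

def bucket_by_topology(asts):
    """Group ASTs with identical shapes so each group can form one batch."""
    buckets = defaultdict(list)
    for ast in asts:
        buckets[shape_signature(ast)].append(ast)
    return dict(buckets)

a = {"type": "Add", "children": [{"type": "Num"}, {"type": "Var"}]}
b = {"type": "Mul", "children": [{"type": "Var"}, {"type": "Num"}]}
c = {"type": "Neg", "children": [{"type": "Num"}]}
buckets = bucket_by_topology([a, b, c])
print(len(buckets))  # 2: {a, b} share a shape, c has its own
```

As trees get larger, exact shape matches likely become rare, so most buckets would hold only a few trees. That is the main weakness of batching whole trees by topology.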

```python
import torch

def forward(asts, order_asts, node_step_forward, batch_size, d_embedding):
    """asts have already been collated according to their matching topology.

    order_asts and node_step_forward are placeholders still to be written.
    """
    # order trees according to how the TreeNN will process them,
    # e.g. post-order for bottom-up
    ordered_batch_nodes = order_asts(asts)
    # do the forward pass, one batched node step at a time
    current_node_embeddings = torch.zeros(batch_size, d_embedding)
    for batch_node in ordered_batch_nodes:
        current_node_embeddings = node_step_forward(batch_node, current_node_embeddings)
    return current_node_embeddings
```
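For the easy case where every tree in the batch shares one topology, the sketch above can be made runnable roughly as follows. This is an illustration under stated assumptions. Nodes are numbered in post-order (children before parents, root last), `children[i]` lists the child indices of node i, and node features arrive as `x` of shape [B, N, D_in]. The cell is a plain tanh recursive unit standing in for a real TreeLSTM cell, and the class name is hypothetical.

```python
import torch
import torch.nn as nn

class SameShapeTreeRNN(nn.Module):
    """Bottom-up tree RNN for a batch of trees sharing one topology (sketch).

    Hypothetical conventions: nodes are in post-order (root last) and
    children[i] lists the child indices of node i; x is [B, N, d_in].
    """
    def __init__(self, d_in, d_hidden):
        super().__init__()
        self.d_hidden = d_hidden
        self.w_in = nn.Linear(d_in, d_hidden)
        self.w_child = nn.Linear(d_hidden, d_hidden, bias=False)

    def forward(self, x, children):
        batch_size, num_nodes, _ = x.shape
        h = [None] * num_nodes
        for i in range(num_nodes):  # post-order: children are already done
            child_sum = x.new_zeros(batch_size, self.d_hidden)
            for c in children[i]:
                child_sum = child_sum + h[c]
            # one step for node i, batched over all B trees at once
            h[i] = torch.tanh(self.w_in(x[:, i]) + self.w_child(child_sum))
        return h[-1]  # root embeddings, [B, d_hidden]

model = SameShapeTreeRNN(d_in=4, d_hidden=8)
children = [[], [], [0, 1]]  # two leaves, then the root
x = torch.randn(5, 3, 4)     # 5 trees with 3 nodes each
print(model(x, children).shape)  # torch.Size([5, 8])
```

Each loop iteration handles one node position for all B trees at once, so the number of sequential steps equals the number of nodes in the shared shape.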

In addition, we’d have to collate the ASTs (I haven’t written pseudocode for this since I’m not sure how to do it). My data is usually already parsed (e.g. as JSON or s-expressions, but I prefer and plan to use JSON for each data point).
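A first piece of the collation step is turning a parsed JSON AST into flat arrays that a batch can be built from. This sketch assumes a hypothetical schema where each node is `{"type": ..., "children": [...]}`. It emits the labels in post-order together with each node's child indices. Mapping labels to vocabulary ids is left out.

```python
import json

def flatten_post_order(node, labels, children):
    """Append `node`'s subtree in post-order; return the index of `node`.

    Assumes (hypothetically) each JSON node looks like
    {"type": <str>, "children": [<node>, ...]}.
    """
    child_ids = [flatten_post_order(c, labels, children)
                 for c in node.get("children", [])]
    labels.append(node["type"])
    children.append(child_ids)
    return len(labels) - 1

def load_ast(json_text):
    """Parse one JSON AST into (labels, children) lists, root last."""
    labels, children = [], []
    flatten_post_order(json.loads(json_text), labels, children)
    return labels, children

labels, children = load_ast(
    '{"type": "Add", "children": [{"type": "Num"}, {"type": "Var"}]}'
)
print(labels)    # ['Num', 'Var', 'Add']
print(children)  # [[], [], [0, 1]]
```

Because it recurses, very deep ASTs could hit Python's recursion limit. An explicit stack would avoid that.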

The problem I describe above does not feel new. TreeNNs (e.g. TreeLSTMs) have been implemented before, so I’d rather use a tested method: the standard way to process and do forward passes on tree data in pytorch. Is there such a tutorial or library for pytorch that one could easily use?


Research I’ve done

I’ve found the following options:

  1. using CoqGym’s example (How to efficiently process a batch of terms in one go for Tree Neural Networks? · Discussion #48 · princeton-vl/CoqGym · GitHub and How to efficiently process a batch of terms in one go for Tree Neural Networks · Issue #16 · princeton-vl/CoqGym · GitHub), though that repo as a whole feels very hacky and their way of doing things seems very custom, which worries me.
  2. this option: PyTorch — Dynamic Batching. If you have been reading my blog, you… | by Illia Polosukhin | NEAR AI | Medium
  3. a pytorch implementation of tree lstms: GitHub - unbounce/pytorch-tree-lstm: Pytorch implementation of the child-sum Tree-LSTM model
  4. yet another link: Recursive Neural Networks with PyTorch | NVIDIA Developer Blog
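Several of the options above avoid the same-topology restriction by batching per node instead of per tree. A common way to do that is sketched here in plain Python; it is my own illustration, not code from any of the linked projects. Every node gets a level: leaves are at 0 and each parent sits one above its highest child. Every node on a level depends only on lower levels, so all of them, across all trees, can be computed in one batched step. The encoding is an assumption: `children[i]` lists the child indices of node i, and children are numbered before their parent.

```python
def node_levels(children):
    """Height of every node: leaves are 0, a parent is 1 + max child height.

    Assumes children[i] lists the child indices of node i and that every
    child has a smaller index than its parent (e.g. post-order numbering).
    """
    level = []
    for kids in children:
        level.append(1 + max(level[c] for c in kids) if kids else 0)
    return level

def schedule(forest):
    """Group (tree_id, node_id) pairs by level across a list of trees."""
    by_level = {}
    for t, children in enumerate(forest):
        for n, lvl in enumerate(node_levels(children)):
            by_level.setdefault(lvl, []).append((t, n))
    return [by_level[lvl] for lvl in sorted(by_level)]

forest = [
    [[], [], [0, 1]],  # tree 0: a root with two leaves
    [[], [0], [1]],    # tree 1: a chain of three nodes
]
for lvl, nodes in enumerate(schedule(forest)):
    print(lvl, nodes)
# 0 [(0, 0), (0, 1), (1, 0)]
# 1 [(0, 2), (1, 1)]
# 2 [(1, 2)]
```

The number of sequential steps becomes the height of the tallest tree instead of the number of nodes, and trees of different shapes can share a batch.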

What worries me is that none of these seem to be part of the standard pytorch ecosystem (Ecosystem | PyTorch), which makes me unsure which one to use (if any). Should I be using any of the above? My guess is that perhaps the 4th one is the one I should use? Ideally, I’d like to implement a custom dataset object that returns ASTs (from my JSON files or s-expressions) and create batches that take advantage of GPU acceleration.
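A custom dataset can stay simple if the batching happens in the `collate_fn` passed to `DataLoader`. In this sketch, the class and function names are hypothetical. Each item is one AST, stored as a LongTensor of node vocabulary ids plus child index lists, with children before parents and the root last. The collate function pools the trees into one forest by shifting each tree's indices by the number of nodes that came before it.

```python
import torch
from torch.utils.data import Dataset, DataLoader

class ASTDataset(Dataset):
    """Hypothetical dataset: each item is (node_ids, children) for one AST.

    node_ids: LongTensor [N] of vocabulary ids, children before parents and
    root last; children[i] lists the child indices of node i.
    """
    def __init__(self, trees):
        self.trees = trees

    def __len__(self):
        return len(self.trees)

    def __getitem__(self, idx):
        return self.trees[idx]

def collate_forest(batch):
    """Pool several trees into one forest by offsetting node indices."""
    node_ids, children, roots, offset = [], [], [], 0
    for ids, kids in batch:
        node_ids.append(ids)
        children.extend([[c + offset for c in k] for k in kids])
        offset += len(ids)
        roots.append(offset - 1)  # root is the last node of each tree
    return torch.cat(node_ids), children, torch.tensor(roots)

trees = [
    (torch.tensor([3, 4, 1]), [[], [], [0, 1]]),
    (torch.tensor([2, 5]), [[], [0]]),
]
loader = DataLoader(ASTDataset(trees), batch_size=2, collate_fn=collate_forest)
node_ids, children, roots = next(iter(loader))
print(node_ids)  # tensor([3, 4, 1, 2, 5])
print(children)  # [[], [], [0, 1], [], [3]]
print(roots)     # tensor([2, 4])
```

`node_ids` and `roots` can be moved to the GPU with `.to(device)`. The child lists would typically be converted to an index tensor (e.g. an edge list) before the forward pass.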


Another Solution: using Graph Neural Network libraries

I also realized that trees are really just a special case of graphs (though perhaps not in the way we do the forward pass, since GNNs seem to have a more complex set of “cycles” to produce embeddings). With that in mind, I found these links:
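Viewed as a graph, a tree is just its list of parent-child edges. Graph libraries commonly store these in COO form as a [2, num_edges] index tensor; PyG calls it `edge_index`. In this minimal sketch, `tree_to_edge_index` is a hypothetical helper that assumes `children[i]` lists the child indices of node i. It directs edges from child to parent so that one message-passing step moves information bottom-up.

```python
import torch

def tree_to_edge_index(children):
    """COO edge list [2, E] with edges pointing child -> parent.

    Directing edges toward the parent lets message passing carry
    information bottom-up; swap the two rows for top-down.
    """
    src, dst = [], []
    for parent, kids in enumerate(children):
        for child in kids:
            src.append(child)
            dst.append(parent)
    return torch.tensor([src, dst], dtype=torch.long)

edge_index = tree_to_edge_index([[], [], [0, 1]])
print(edge_index)
# tensor([[0, 1],
#         [2, 2]])
```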

  1. question from the pytorch forum: What is an efficient way to do a forward pass of a batch of data for an Graph Neural Network (GNN)? - #2 by Brando_Miranda
  2. deep graph library: https://www.dgl.ai/
  3. geometric library: GitHub - rusty1s/pytorch_geometric: Geometric Deep Learning Extension Library for PyTorch

The last two are part of the pytorch ecosystem. However, they are built for GNNs, so I am unsure whether they are really worth the effort of checking out before deciding what to use.

Overall: What is the standard/recommended way to generate batches and process batches of tree data in pytorch to leverage the power of GPUs?


cross posted:


A possible solution is to use a graph NN library to do the heavy work:

The challenge in training Tree-LSTMs is batching — a standard technique in machine learning to accelerate optimization. However, since trees generally have different shapes by nature, parallelization is non-trivial. DGL offers an alternative. Pool all the trees into one single graph then induce the message passing over them, guided by the structure of each tree.

e.g. from dgl: Tutorial: Tree-LSTM in DGL — DGL 0.4.3post2 documentation
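The DGL tutorial does this with DGL's own batching and message-passing utilities. The same idea can be sketched in plain PyTorch, without claiming to match DGL's implementation. All trees are pooled into one graph with edges pointing child to parent. Every node gets a level (leaves 0, parent = 1 + max child level). The levels are processed in order, so each step updates every node on that level, across all trees, with one batched operation. The class name, the input conventions, and the simple tanh child-sum cell (used instead of Tree-LSTM gates) are assumptions for illustration.

```python
import torch
import torch.nn as nn

class LevelBatchedTreeRNN(nn.Module):
    """Bottom-up tree RNN over a forest pooled into one graph (sketch).

    Hypothetical inputs: x [N, d_in] features for every node of every tree,
    edge_index [2, E] with edges child -> parent, and level [N] where
    leaves are 0 and a parent is 1 + max child level.
    """
    def __init__(self, d_in, d_hidden):
        super().__init__()
        self.w_in = nn.Linear(d_in, d_hidden)
        self.w_child = nn.Linear(d_hidden, d_hidden, bias=False)

    def forward(self, x, edge_index, level):
        src, dst = edge_index
        h = x.new_zeros(x.shape[0], self.w_child.out_features)
        for lvl in range(int(level.max()) + 1):
            nodes = (level == lvl).nonzero(as_tuple=True)[0]
            # child-sum: children sit on strictly lower levels, so their
            # states are final by the time their parent's level runs
            agg = torch.zeros_like(h).index_add(0, dst, h[src])
            new_h = torch.tanh(self.w_in(x[nodes]) + self.w_child(agg[nodes]))
            # out-of-place update, so earlier states stay valid for autograd
            h = h.index_copy(0, nodes, new_h)
        return h

# Forest: tree A = nodes 0, 1 (leaves) -> 2 (root); tree B = node 3 -> 4 (root)
model = LevelBatchedTreeRNN(d_in=4, d_hidden=8)
x = torch.randn(5, 4)
edge_index = torch.tensor([[0, 1, 3],
                           [2, 2, 4]])
level = torch.tensor([0, 0, 1, 0, 1])
h = model(x, edge_index, level)
roots = h[torch.tensor([2, 4])]  # one embedding per tree
print(roots.shape)  # torch.Size([2, 8])
```

For simplicity, the child sum is recomputed over all edges at every level. Restricting it to edges that end on the current level would save work on large forests.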