What's the best practice for implementing a tree of models in parallel?

Hi all, I’d like to implement the paper “Monocular Depth Estimation Using Neural Regression Forest” (http://web.engr.oregonstate.edu/~sinisa/research/publications/cvpr16_NRF.pdf) using PyTorch.

The model works like this: it contains many mini-models (splitting nodes) structured as a tree. Every splitting node is an nn.Module that receives input from its parent node and sends its output to its child nodes.

For example, for a binary tree of height 3 we need three splitting-node models: node0, node1, and node2. The input tensor x is sent to node0, whose output is then passed on to its children:

```python
output0 = node0(x)
output1 = node1(output0)
output2 = node2(output0)
```

This process continues if the tree is deeper.
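As a concrete sketch, here is how those three nodes could be wired up as nn.Modules. The `SplittingNode` architecture below (a single linear layer plus ReLU) is a hypothetical stand-in for the paper's CNN sub-networks, just to show the data flow:

```python
import torch
import torch.nn as nn

class SplittingNode(nn.Module):
    """One splitting node: a small hypothetical sub-network."""
    def __init__(self, dim):
        super().__init__()
        self.net = nn.Sequential(nn.Linear(dim, dim), nn.ReLU())

    def forward(self, x):
        return self.net(x)

# Three nodes arranged as the binary tree in the example above.
dim = 8
node0, node1, node2 = SplittingNode(dim), SplittingNode(dim), SplittingNode(dim)

x = torch.randn(4, dim)      # batch of 4 inputs
output0 = node0(x)           # root
output1 = node1(output0)     # left child
output2 = node2(output0)     # right child
print(output1.shape, output2.shape)  # torch.Size([4, 8]) torch.Size([4, 8])
```

Wrapping the nodes in an nn.ModuleList (rather than a plain Python list) would make sure their parameters are registered with the parent module.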

In addition, we have an ensemble of these trees forming a forest. In the forward pass, it seems unwise to evaluate the nodes sequentially with a for loop like this:

```python
# splitting_nodes is a list of nn.Modules, one for each splitting node
for i in range(num_splitting_nodes):
    parent_index = find_parent(i)
    parent_output = find_output(parent_index)
    output_i = splitting_nodes[i](parent_output)
```

In fact, models from different trees are independent, and within a tree the forward passes of all nodes at the same depth are independent as well. A parallel implementation is therefore tempting. The question is what the best way is to implement this in PyTorch for maximum efficiency.

Should I use multiprocessing (or a thread pool) to evaluate the unrelated models concurrently? Or should I merge the models at the same depth into one "big" model (for example, combine several CNNs with identical architectures into a single big CNN) and run a single forward pass to get the results for all splitting nodes at that depth?
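On the second option, grouped convolutions are one way to fuse same-depth CNN nodes into a single module: with `groups=num_nodes`, one nn.Conv2d behaves as several independent convolutions evaluated in a single kernel launch. A minimal sketch (the node count, channel sizes, and 3×3 kernel are all hypothetical):

```python
import torch
import torch.nn as nn

# Hypothetical sizes: 4 splitting nodes at the same depth, each mapping
# c_in -> c_out channels. groups=4 partitions the channels so the big conv
# acts as 4 independent convolutions, one per node.
num_nodes, c_in, c_out = 4, 3, 8
merged_conv = nn.Conv2d(num_nodes * c_in, num_nodes * c_out,
                        kernel_size=3, padding=1, groups=num_nodes)

# Stack the per-node inputs along the channel dimension.
x = torch.randn(2, num_nodes * c_in, 16, 16)   # batch of 2
y = merged_conv(x)
print(y.shape)  # torch.Size([2, 32, 16, 16])

# Split back into per-node outputs if needed.
per_node = y.chunk(num_nodes, dim=1)           # 4 tensors of shape (2, 8, 16, 16)
```

This keeps everything in one forward pass per depth level, which tends to use the GPU far better than launching many tiny convolutions (whether sequentially or from multiple threads).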

I'd appreciate your experiences and help!