Efficient way to produce batch-specific tree-structured output

Hi,
I’m trying to build a kind of recursive network that takes in sequential input and produces binary-tree-structured output.
After processing the input and doing some computations, the model produces an intermediate tensor y_hat of shape (batch,), whose values determine whether each sample in the batch needs to be split further (1 means split, 0 means don’t). Since different samples in the batch produce trees of different sizes, I can’t figure out how to process this efficiently.
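A minimal sketch of the batch split that y_hat describes, assuming y_hat is a 0/1 tensor of shape (batch,). Besides masking, it keeps each sample's original row index, so per-group results can be written back in batch order instead of being concatenated in "stopped first, split second" order. partition_by_split and merge_in_order are hypothetical helper names, not PyTorch API.

```python
import torch

def partition_by_split(inputs, y_hat):
    """Split a batch into (stop, split) groups, remembering original row indices.

    y_hat: (batch,) tensor of 0/1 decisions, 1 = split further, 0 = stop.
    Boolean masking yields possibly-empty tensors, never None.
    """
    split_mask = y_hat.bool()
    idx = torch.arange(inputs.size(0), device=inputs.device)
    stop = (inputs[~split_mask], idx[~split_mask])
    split = (inputs[split_mask], idx[split_mask])
    return stop, split

def merge_in_order(parts, batch_size):
    """Write each group's rows back to their original batch positions.

    Assumes every group produces outputs with the same per-sample shape.
    """
    values0 = parts[0][0]
    out = values0.new_empty((batch_size,) + tuple(values0.shape[1:]))
    for values, idx in parts:
        out[idx] = values
    return out

x = torch.arange(8.0).view(4, 2)
y_hat = torch.tensor([1, 0, 1, 0])
(stop_x, stop_idx), (split_x, split_idx) = partition_by_split(x, y_hat)
# stop_idx is tensor([1, 3]), split_idx is tensor([0, 2])
restored = merge_in_order([(stop_x, stop_idx), (split_x, split_idx)], x.size(0))
```

Scattering by index (out[idx] = values) avoids both padding and reliance on concatenation order.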

Currently I’m thinking of doing something like this:

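import torch

def buildTree(self, inputs):
  # y_hat has shape (batch,): 1 = split this sample further, 0 = stop.
  # self.splitDecision is a hypothetical stand-in for "do some computations".
  y_hat = self.splitDecision(inputs)
  completedTreeBatch = inputs[y_hat == 0]
  splitTreeBatch = inputs[y_hat == 1]
  # Boolean-mask indexing returns a (possibly empty) tensor, never None,
  # so `!= None` is not a valid stopping test; check the batch size instead.
  if splitTreeBatch.size(0) > 0:
    returnTreeBatch = self.buildTree(splitTreeBatch)
    # Completed samples come first here, so the original batch order is lost.
    completedTreeBatch = torch.cat((completedTreeBatch, returnTreeBatch))
  return completedTreeBatch

def forward(self, inputs):
  """ Do some computations """
  treeOutput = self.buildTree(inputs)
  return treeOutput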

This would involve a lot of unnecessary padding and concatenation. Is there a better way to go about this?
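One common alternative to recursion, sketched here under the assumption that each split produces two child states: grow all trees breadth-first, one depth level at a time, so each level costs a single batched call over every node still being expanded across the whole batch. The forest is stored as flat per-node tensors (batch id, parent index, leaf flag) rather than padded per-sample trees. expand_trees, split_fn and toy_split are hypothetical names; split_fn stands in for the model's own computation and must return y_hat plus the left and right child states.

```python
import torch

def expand_trees(root_states, split_fn, max_depth=8):
    """Grow one binary tree per batch element, breadth-first, without padding.

    root_states: (batch, d) tensor, one root state per sample.
    split_fn: hypothetical callable mapping node states (n, d) to
        (y_hat (n,), left (n, d), right (n, d)), with y_hat 1 = split, 0 = leaf.
    Returns flat per-node tensors for the whole forest, in node-id order:
        states (N, d), batch_id (N,), parent (N,) (-1 for roots), is_leaf (N,).
    """
    batch = root_states.size(0)
    device = root_states.device
    frontier = root_states
    frontier_ids = torch.arange(batch, device=device)    # global node ids
    frontier_batch = torch.arange(batch, device=device)  # sample each node belongs to
    states, batch_ids = [frontier], [frontier_batch]
    parents = [torch.full((batch,), -1, dtype=torch.long, device=device)]
    leaf_flags = []
    next_id = batch

    for depth in range(max_depth + 1):
        # One batched call for every node at this depth, across all samples
        y_hat, left, right = split_fn(frontier)
        if depth < max_depth:
            split = y_hat.bool()
        else:
            split = torch.zeros_like(y_hat, dtype=torch.bool)  # force leaves at max depth
        leaf_flags.append(~split)
        if not split.any():
            break
        # Children of all splitting nodes: left children first, then right children
        parent_ids = frontier_ids[split]
        frontier = torch.cat([left[split], right[split]])
        frontier_batch = frontier_batch[split].repeat(2)
        n_new = frontier.size(0)
        frontier_ids = torch.arange(next_id, next_id + n_new, device=device)
        next_id += n_new
        states.append(frontier)
        batch_ids.append(frontier_batch)
        parents.append(parent_ids.repeat(2))

    return torch.cat(states), torch.cat(batch_ids), torch.cat(parents), torch.cat(leaf_flags)

def toy_split(s):
    """Toy stand-in for the model: split while the feature is >= 2, halving it."""
    y_hat = (s[:, 0] >= 2).long()
    return y_hat, s / 2, s / 2

states, batch_id, parent, is_leaf = expand_trees(torch.tensor([[4.0], [1.0]]), toy_split)
node_counts = torch.bincount(batch_id)  # tensor([7, 1]): sample 0 has 7 nodes, sample 1 has 1
```

Per-tree reductions (node counts, losses summed per sample) can then be computed from batch_id with segment-style operations instead of padding every tree to the largest one.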
Also, is there a more efficient way to implement something like beam search to build the tree in PyTorch?
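For the beam-search part, the core of a batched beam search is a single torch.topk over all (hypothesis, decision) pairs per batch element, which avoids Python loops over samples. This is a generic sketch, assuming each step scores k candidate decisions; treating them as tree decisions (for example k = 2 for stop/split at the next node to expand) is an assumption, as are the helper names beam_step and reorder_state.

```python
import torch

def beam_step(beam_scores, cand_logprobs, beam_size):
    """One batched beam-search step.

    beam_scores:   (batch, beam) cumulative log-probs of the current hypotheses.
    cand_logprobs: (batch, beam, k) log-probs of k possible next decisions.
    Returns (new_scores, parent, decision), each (batch, beam_size):
    the new cumulative scores, which hypothesis each survivor extends,
    and which of the k decisions it took.
    """
    batch, beam, k = cand_logprobs.shape
    total = beam_scores.unsqueeze(-1) + cand_logprobs    # (batch, beam, k)
    flat = total.reshape(batch, beam * k)                # rank all continuations together
    new_scores, flat_idx = flat.topk(beam_size, dim=-1)  # per-sample top-k, no Python loop
    parent = torch.div(flat_idx, k, rounding_mode="floor")
    decision = flat_idx % k
    return new_scores, parent, decision

def reorder_state(state, parent):
    """Make per-hypothesis state of shape (batch, beam, ...) follow its surviving parent."""
    batch_idx = torch.arange(state.size(0), device=state.device).unsqueeze(-1)
    return state[batch_idx, parent]

scores = torch.tensor([[0.0, -1.0]])                 # batch = 1, beam = 2
cand = torch.tensor([[[-0.1, -2.0], [-0.5, -0.2]]])  # k = 2 candidate decisions
new_scores, parent, decision = beam_step(scores, cand, beam_size=2)
# parent is tensor([[0, 1]]), decision is tensor([[0, 1]])
```

Any per-hypothesis state (hidden states, partial trees) has to be passed through reorder_state after each step so it stays aligned with the hypotheses that survived.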

Any help would be really appreciated!