I’m trying to build a kind of recursive network that takes sequential input and produces binary-tree-structured output.
After processing the input and doing some computations, the model produces an intermediate `y_hat` of shape `(batch,)`, whose values determine whether each sample in the batch needs to be split further (1 means split, 0 means don’t). Since different samples in the batch produce trees of different sizes, I can’t figure out how to process this efficiently.
Currently I’m thinking of doing something like this:
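```python
import torch

def buildTree(self, inputs):
    y_hat = ...  # Do some computations; y_hat has shape (batch,) with values 0 or 1
    completedTreeBatch = inputs[y_hat == 0]  # samples that stop splitting
    splitTreeBatch = inputs[y_hat == 1]      # samples that need another split
    # A tensor can't be meaningfully compared to None; check whether any rows remain
    if splitTreeBatch.size(0) > 0:
        returnTreeBatch = self.buildTree(splitTreeBatch)
        # Note: rows end up in [finished, split] order, not in input order
        completedTreeBatch = torch.cat((completedTreeBatch, returnTreeBatch))
    return completedTreeBatch

def forward(self, inputs):
    ...  # Do some computations
    treeOutput = self.buildTree(inputs)
    return treeOutput
```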
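One detail in the recursive version: each `torch.cat` puts finished rows before split rows, so the output no longer lines up with the input batch. A minimal sketch of an iterative alternative, assuming the split samples keep the same feature size after each step, carries the original row index of every active sample and writes finished rows straight back into place. `build_tree_ordered`, `decide`, `step` and `max_depth` are hypothetical names standing in for the "Do some computations" parts.

```python
import torch

def build_tree_ordered(inputs, decide, step, max_depth=10):
    """Hypothetical helper mirroring buildTree without recursion.

    decide: (n, d) -> (n,) tensor of 0/1 split decisions (assumed).
    step:   (n, d) -> (n, d) transform applied to samples that split (assumed).
    Returns outputs in the same row order as `inputs`, plus each sample's depth.
    """
    out = inputs.clone()
    depth = torch.zeros(inputs.size(0), dtype=torch.long)
    idx = torch.arange(inputs.size(0))  # original row index of each active sample
    active = inputs
    for d in range(max_depth):
        split = decide(active) == 1
        # Finished samples are written back to their original rows
        out[idx[~split]] = active[~split]
        depth[idx[~split]] = d
        if not split.any():
            return out, depth
        # Only the samples that split stay active; no padding is needed
        idx, active = idx[split], step(active[split])
    # Anything still active after max_depth steps is stored as-is
    out[idx] = active
    depth[idx] = max_depth
    return out, depth
```

For example, with `decide = lambda x: (x[:, 0] > 0).long()` and `step = lambda x: x - 1`, inputs `[[2.], [0.], [1.]]` give depths `[2, 0, 1]` in the original order.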
This seems to involve a lot of unnecessary padding and concatenation. Is there a better way to go about this?
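One common way to avoid padding is to build all trees level by level (breadth-first) and store them as a flat list of nodes with parent pointers, so every node at the current depth, across the whole batch, is processed in one call. The sketch below assumes two hypothetical callables: `decide`, which returns 0/1 split decisions for a batch of node states, and `make_children`, which maps split nodes to left and right child states. It still concatenates, but only once per level rather than once per recursive call.

```python
import torch

def build_trees_bfs(root_states, decide, make_children, max_depth=8):
    """Level-order construction of one binary tree per sample, without padding.

    root_states:   (batch, d) tensor, one root node per sample.
    decide:        hypothetical fn (n, d) -> (n,) tensor of 0/1 split decisions.
    make_children: hypothetical fn (n, d) -> (left (n, d), right (n, d)).
    Returns flat tensors over all nodes of all trees:
      states (N, d), sample_id (N,), parent_id (N,) with -1 for roots, is_split (N,).
    """
    batch = root_states.size(0)
    frontier = root_states
    frontier_sample = torch.arange(batch)
    frontier_parent = torch.full((batch,), -1, dtype=torch.long)
    states, sample_id, parent_id, is_split = [], [], [], []
    num_nodes = 0
    for depth in range(max_depth + 1):
        n = frontier.size(0)
        if n == 0:
            break
        node_ids = torch.arange(num_nodes, num_nodes + n)  # global ids for this level
        split = decide(frontier) == 1
        if depth == max_depth:
            split = torch.zeros_like(split)  # force leaves at the depth limit
        states.append(frontier)
        sample_id.append(frontier_sample)
        parent_id.append(frontier_parent)
        is_split.append(split)
        num_nodes += n
        # All split nodes from every sample expand together in a single call
        left, right = make_children(frontier[split])
        frontier = torch.cat((left, right))
        # repeat(2) tiles [a, b] -> [a, b, a, b], matching the (left, right) order
        frontier_sample = frontier_sample[split].repeat(2)
        frontier_parent = node_ids[split].repeat(2)
    return (torch.cat(states), torch.cat(sample_id),
            torch.cat(parent_id), torch.cat(is_split))
```

Per-sample trees can then be recovered with masks such as `sample_id == i`, and any per-node loss can be computed on the flat tensors directly.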
Also, is there a more efficient way to implement something like beam search to build the tree in PyTorch?
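For the beam search part, the core step can usually be vectorized over the whole batch with `topk`. The following is a generic sketch of one batched beam-search step, not a tree-specific algorithm: for tree building, the "actions" might be the two split decisions (`num_actions = 2`), and each hypothesis would be a partial tree whose state you would then reorder with the returned `parent_beam`. The function name and shapes are assumptions for illustration.

```python
import torch

def beam_search_step(beam_scores, step_log_probs, beam_size):
    """One batched beam-search step (hypothetical helper).

    beam_scores:    (batch, beam) cumulative log-probs of current hypotheses.
    step_log_probs: (batch, beam, num_actions) log-probs of each next action.
    Returns new scores (batch, beam_size), the parent beam index and the
    chosen action for each surviving hypothesis.
    """
    batch, beam, num_actions = step_log_probs.shape
    # Score of every (hypothesis, action) extension
    total = beam_scores.unsqueeze(-1) + step_log_probs  # (batch, beam, num_actions)
    flat = total.reshape(batch, beam * num_actions)
    # Keep the best `beam_size` extensions per sample, all samples at once
    new_scores, flat_idx = flat.topk(beam_size, dim=-1)
    parent_beam = torch.div(flat_idx, num_actions, rounding_mode="floor")
    action = flat_idx % num_actions
    return new_scores, parent_beam, action
```

Hypothesis states of shape `(batch, beam, d)` can then be reordered with `torch.gather(states, 1, parent_beam.unsqueeze(-1).expand(-1, -1, d))` before the next step.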
Any help would be really appreciated!