Hi,
I’m trying to build a kind of recursive network that takes sequential input and produces a binary-tree-structured output.
After processing the input and doing some computations, the model produces an intermediate y_hat of shape (batch,), the values of which determine whether each sample in the batch needs to be split further or not (1 means split, 0 means don’t). Since different samples in the batch will produce different-sized trees, I can’t figure out how to process this efficiently.
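As a toy illustration of that split decision, here is a small sketch using a boolean mask. The tensor values are made up for the example and are not the model's real output.

```python
import torch

# Made-up batch of 4 samples with 3 features each (illustrative values only)
inputs = torch.arange(12, dtype=torch.float32).reshape(4, 3)
# Made-up y_hat of shape (batch,): 1 means split further, 0 means done
y_hat = torch.tensor([1, 0, 1, 0])

done = inputs[y_hat == 0]      # samples whose tree is complete
to_split = inputs[y_hat == 1]  # samples that need another level

print(done.shape, to_split.shape)
```

The two sub-batches generally have different sizes, and those sizes change at every level, which is where the difficulty comes from.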
Currently I’m thinking of doing something like this:
def buildTree(self, inputs):
    y_hat = """ Do some computations """  # placeholder: tensor of shape (batch,) with 0/1 entries
    completedTreeBatch = inputs[y_hat == 0]
    splitTreeBatch = inputs[y_hat == 1]
    # Boolean-mask indexing never returns None, so check for an empty batch instead
    if splitTreeBatch.shape[0] > 0:
        returnTreeBatch = self.buildTree(splitTreeBatch)
        completedTreeBatch = torch.cat((completedTreeBatch, returnTreeBatch))
    return completedTreeBatch

def forward(self, inputs):
    """ Do some computations """
    treeOutput = self.buildTree(inputs)
    return treeOutput
This would involve a lot of unnecessary padding and concatenating. Is there any better way to go about this?
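For concreteness, a minimal runnable version of the recursion might look like the sketch below. `ToyTreeBuilder`, `step`, `split_head`, and `max_depth` are hypothetical stand-ins for the elided computations, not the actual model, and the depth cap is an added assumption so the recursion always terminates. It also carries each sample's original batch index along, since the concatenation reorders the rows.

```python
import torch
import torch.nn as nn


class ToyTreeBuilder(nn.Module):
    """Hypothetical stand-in for the real model; all layer names are assumptions."""

    def __init__(self, dim, max_depth=4):
        super().__init__()
        self.step = nn.Linear(dim, dim)      # stand-in for "do some computations"
        self.split_head = nn.Linear(dim, 1)  # stand-in for whatever produces y_hat
        self.max_depth = max_depth           # assumed cap so the recursion terminates

    def build_tree(self, inputs, idx, depth=0):
        # inputs: (n, dim) active samples; idx: (n,) their rows in the original batch
        hidden = torch.tanh(self.step(inputs))
        y_hat = (self.split_head(hidden).squeeze(-1) > 0).long()  # shape (n,)
        if depth >= self.max_depth:
            y_hat = torch.zeros_like(y_hat)  # force every sample to stop at the cap
        done = y_hat == 0
        out, out_idx = hidden[done], idx[done]
        out_depth = torch.full_like(out_idx, depth)  # depth at which each sample stopped
        if (~done).any():
            # Recurse only on the samples that still need splitting
            sub, sub_idx, sub_depth = self.build_tree(hidden[~done], idx[~done], depth + 1)
            out = torch.cat((out, sub))
            out_idx = torch.cat((out_idx, sub_idx))
            out_depth = torch.cat((out_depth, sub_depth))
        return out, out_idx, out_depth

    def forward(self, inputs):
        idx = torch.arange(inputs.shape[0], device=inputs.device)
        return self.build_tree(inputs, idx)


torch.manual_seed(0)
model = ToyTreeBuilder(dim=8)
out, out_idx, out_depth = model(torch.randn(5, 8))
```

Like the sketch above, this produces one output row per sample rather than a full tree, and it still calls `torch.cat` once per level, so it only makes the question concrete rather than solving it.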
Also, is there a more efficient way to implement something like beam search to build the tree in PyTorch?
Any help would be really appreciated!