Binary Tree lookup with PyTorch?

I am attempting to build a binary tree lookup table and would like to do it with torch tensors. Is there a known approach for this? For example, I’m currently learning about the torch.nn.Embedding class, and I suspect it could be used here, but I’m struggling to see how.
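If the tree is complete (every leaf sits at the same depth), nn.Embedding does fit naturally. The root-to-leaf path of left (0) / right (1) choices, read as a binary number, is exactly the leaf’s index, so the leaves can live in a table with 2**depth rows. Below is a minimal sketch under that assumption. The depth, the leaf values and the helper `path_to_index` are hypothetical, chosen only for illustration.

```python
import torch

# Assumption: a complete binary tree of depth 3, where each lookup is given as
# a sequence of left (0) / right (1) choices from the root down to a leaf.
depth = 3
num_leaves = 2 ** depth

# nn.Embedding is a (num_embeddings x embedding_dim) table indexed by integer ids.
leaf_table = torch.nn.Embedding(num_leaves, 1)
with torch.no_grad():
    # Hypothetical leaf values: leaf i stores 10 * i.
    leaf_table.weight.copy_(torch.arange(num_leaves, dtype=torch.float32).unsqueeze(1) * 10)

def path_to_index(paths):
    """Hypothetical helper. paths: (batch, depth) 0/1 tensor, root choice first."""
    place_values = 2 ** torch.arange(depth - 1, -1, -1)  # tensor([4, 2, 1])
    return (paths * place_values).sum(dim=1)

paths = torch.tensor([[0, 0, 0], [1, 0, 1], [1, 1, 1]])
idx = path_to_index(paths)           # leaf ids 0, 5, 7
values = leaf_table(idx).squeeze(1)  # leaf values 0, 50, 70
```

nn.Embedding only matters if the leaf values need to be learned by gradient descent; for a fixed table, plain indexing such as `leaf_values[idx]` on an ordinary tensor does the same lookup.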

My current implementation creates a class object for each node that holds either pointers to its child branches or a value to be returned. A lookup calls this recursively until it reaches a value, but it is quite slow. My intuition is that there should be a way to build a tensor that works as a lookup table when I supply the indices. I’m rather a novice, though, and there is a lot I don’t know about what is possible here.
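For an irregular tree, one common alternative to node objects is to flatten the tree into tensors: one row per node holding its (left, right) child ids, with -1 marking a leaf, plus a tensor of leaf values. A whole batch of lookups can then walk the tree together, one level per loop iteration, instead of recursing once per query. This is a minimal sketch assuming each lookup is supplied as a sequence of 0/1 branch choices; the tree, `children`, `values` and `lookup` are hypothetical, and if your nodes choose a branch some other way (e.g. by comparing an input against a threshold), that test would replace the `bits[:, step]` column.

```python
import torch

# Hypothetical flattened tree: row i of `children` holds the (left, right) child
# ids of node i, or (-1, -1) if node i is a leaf. Node 0 is the root.
children = torch.tensor([
    [1, 2],    # node 0: root
    [-1, -1],  # node 1: leaf
    [3, 4],    # node 2: internal
    [-1, -1],  # node 3: leaf
    [-1, -1],  # node 4: leaf
])
# Value returned at each leaf; entries for internal nodes are never read.
values = torch.tensor([0.0, 10.0, 0.0, 30.0, 40.0])
max_depth = 2  # longest root-to-leaf path

def lookup(bits):
    """bits: (batch, max_depth) long tensor of choices, 0 = left, 1 = right."""
    node = torch.zeros(bits.shape[0], dtype=torch.long)  # every query starts at the root
    for step in range(max_depth):
        is_leaf = children[node, 0] < 0
        # Advanced indexing picks each query's chosen child in one batched op.
        nxt = children[node, bits[:, step]]
        # Queries that have already reached a leaf stay where they are.
        node = torch.where(is_leaf, node, nxt)
    return values[node]

bits = torch.tensor([[0, 0], [1, 0], [1, 1], [0, 1]])
print(lookup(bits))  # tensor([10., 30., 40., 10.])
```

The loop runs `max_depth` times regardless of batch size, and choices past a leaf are simply ignored. Moving `children`, `values` and `bits` to the same device (e.g. with `.to("cuda")`) should let the same code run on a GPU.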