Many of the RvNN (recursive/tree neural network) tutorials and samples for torch that I found online use the Python API. This allows them to implement a forward method with a signature like
def forward(self, tree):
    ...
Here, however, the argument (tree) is not a tensor but a tree data structure that forward traverses.
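For reference, a minimal sketch of that Python pattern, assuming a binary tree whose leaves hold word indices. The Tree class, the layer sizes and the tanh composition are hypothetical choices for illustration, not taken from any particular tutorial:

```python
from dataclasses import dataclass, field
from typing import List, Optional

import torch
import torch.nn as nn


@dataclass
class Tree:
    # Hypothetical node type: leaves carry a word index, internal nodes carry children.
    word: Optional[int] = None
    children: List["Tree"] = field(default_factory=list)


class RvNN(nn.Module):
    def __init__(self, vocab_size: int, dim: int):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, dim)
        # Combines two child vectors into one parent vector of the same size.
        self.compose = nn.Linear(2 * dim, dim)

    def forward(self, tree: Tree) -> torch.Tensor:
        # Leaf: look up the word embedding, shape (1, dim).
        if not tree.children:
            return self.embed(torch.tensor([tree.word]))
        # Internal node (assumed binary): recurse into both children, then compose.
        left, right = (self.forward(child) for child in tree.children)
        return torch.tanh(self.compose(torch.cat([left, right], dim=-1)))


model = RvNN(vocab_size=10, dim=4)
tree = Tree(children=[Tree(word=1), Tree(children=[Tree(word=2), Tree(word=3)])])
print(model(tree).shape)  # torch.Size([1, 4])
```

Because nn.Module.__call__ forwards its arguments unchanged, model(tree) works even though tree is not a tensor; the recursion mirrors the tree's shape, which is also why each tree is processed one node at a time.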
When trying to reimplement this in C++, I ran into the problem that, to overload the forward method, the argument needs to be of type torch::Tensor.
Hence, I am wondering: what is the standard way to implement RvNNs in C++?
One option would probably be to store the tree as a member of the RvNN module and traverse it from there. However, I suspect that this would make batching (and similar things) impossible.