Problem
I have made a PyTorch implementation of a model which is basically a Graph Neural Net (GNN) as I understand it from here. I’m representing first-order logic statements (clauses) as trees and then hoping to come up with a vector embedding for them using my PyTorch model. My hope is that I can feed this embedding as input to a binary classifier which will be trained end-to-end with this embedding model. The goal is to determine whether a logical statement can be useful in a mathematical proof of a given conjecture. Therefore, the binary classifier will take a vector embedding of the clause and a vector embedding of the conjecture as input and try to output whether or not it thinks the clause is useful for proving the conjecture.
This all works and my model does train, but I believe it should train much faster. I considered using PyTorch Geometric to implement my GNN, but there is another constraint: the model must be runnable from libtorch in C++ so that I can embed it into a theorem prover written in C.
I thought that calling torch.jit.script on my model before training would make it train faster, but no such luck. Out of ideas, I tried rewriting my entire model in libtorch in C++. (C++ is fast, right?!) Now that I’ve done that, I’m running into a weird issue: the backward pass is excruciatingly slow (around 30 seconds for a single backward pass). I saw a page about that here, but didn’t find it applicable or helpful to my case.
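For reference, this is all my scripting attempt amounted to (a minimal sketch with a toy stand-in module, not my real model):

```python
import torch
import torch.nn as nn

# Hypothetical stand-in for the real model, just to show the scripting call.
class Tiny(nn.Module):
    def __init__(self):
        super().__init__()
        self.lin = nn.Linear(4, 4)

    def forward(self, x):
        return self.lin(x)

model = torch.jit.script(Tiny())  # compile to TorchScript before training
out = model(torch.randn(2, 4))    # use the scripted module exactly like the original
```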
People here told me before that C++ wouldn’t speed up a model in general (because both Python and C++ call into the same backend functions), but I still believed it should speed up my code, since the slowest part of my Python model (by far) was a few dynamic nested Python for-loops that I couldn’t figure out how to vectorize into a small number of nice PyTorch operations.
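To illustrate what I mean by vectorizing: in the regular case where every node has the same number of children, gathering child vectors can be one indexing operation instead of nested Python loops. My trees are ragged, though, so I haven’t figured out how to do this in general (the tensor shapes below are made up for the example):

```python
import torch

fs = torch.randn(5, 8)                     # updated feature vectors, one row per node
childIdx = torch.tensor([[1, 2], [3, 4]])  # two parents, each with exactly 2 children
children = fs[childIdx]                    # one gather: shape (2, 2, 8), no Python loop
pooled = children.sum(dim=1)               # e.g. sum-pool each parent's children
```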
I’m sorry that this question is super long, but it’s hard to explain where my difficulty lies… I don’t know what else to try to make my model better. Do you have any ideas for making the code below more Pythonic (PyTorch-ic?)? More generally, do you have any tricks for vectorizing complicated logic? Any help would be greatly appreciated.
More context about the model:
The goal of the model is to take in a graph and embed it as a fixed length vector.
The model involves iteratively updating the representation of each node in the tree via a function of itself and its children. I’ve implemented that “function of itself and its children” as a fairly simple submodule (a concatenation followed by a linear layer or two).
After updating every node, we repeat this process for the depth of the tree. (The intuition is to pass information up from the leaves to the root.) Finally, we take the vector that has settled at the root of the tree as the representation of the entire clause.
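For concreteness, here is the input format the code below expects (a small hand-made example using my own encoding convention): each entry of `tree` is `[nodeNumber, child1, child2, ...]`, the root comes first, and `treeSymbols[i]` gives the symbol index for node `i`.

```python
# A 3-node tree: root node 0 with children 1 and 2 (both leaves).
tree = [
    [0, 1, 2],  # node 0; its children are nodes 1 and 2
    [1],        # node 1; no children
    [2],        # node 2; no children
]
treeSymbols = [5, 0, 3]  # symbol index for each node, indexed by node number
```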
```python
import torch
import torch.nn as nn
from typing import List

# NodeEmbedding, NeighborhoodEmbedding, and getTreeDepth are defined elsewhere.

class TreeEmbedding(nn.Module):
    """Turns a clause/context tree into a vector."""

    def __init__(self, signatureSize, symbolDim):
        super().__init__()
        self.symbolVectors = torch.nn.Parameter(torch.randn(signatureSize, symbolDim))
        self.nodeEmbedding = NodeEmbedding(symbolDim, symbolDim)
        self.neighborhoodEmbedding = NeighborhoodEmbedding()

    def forward(self, tree: List[List[int]], treeSymbols: List[int]):
        nodeNumbers = [v[0] for v in tree]    # v[0] is the node's number
        nodeChildren = [v[1:] for v in tree]  # remaining entries are its children
        root = nodeNumbers[0]
        # Initialize each node with the learned embedding of its symbol.
        nodeVecs = torch.stack([self.symbolVectors[treeSymbols[i]] for i in nodeNumbers])
        # One round of updates per tree level, passing information from leaves to root.
        for _ in range(getTreeDepth(tree)):
            fs = self.nodeEmbedding(nodeVecs)
            newNodeVecs = []
            for i in range(len(nodeNumbers)):
                children = [fs[j] for j in nodeChildren[i]]
                newNodeVecs.append(self.neighborhoodEmbedding(fs[i], children))
            nodeVecs = torch.stack(newNodeVecs)
        return nodeVecs[root]
```