How to vectorize PyTorch code (Graph Neural Net)

Problem

I have made a PyTorch implementation of a model which is basically a Graph Neural Net (GNN) as I understand it from here. I’m representing first-order logic statements (clauses) as trees and then hoping to come up with a vector embedding for them using my PyTorch model. My hope is that I can feed this embedding as input to a binary classifier which will be trained end-to-end with this embedding model. The goal is to determine whether a logical statement can be useful in a mathematical proof of a given conjecture. Therefore, the binary classifier will take a vector embedding of the clause and a vector embedding of the conjecture as input and try to output whether or not it thinks the clause is useful for proving the conjecture.

This is all good and my model can even train, but I believe it should be able to train much faster. I considered using PyTorch Geometric to implement my GNN, but another restriction is that I need to be able to run it from libtorch in C++ so that I can embed it into a theorem prover written in C.

I thought that calling torch.jit.script on my model before training would make it train faster, but I had no such luck. Since I was out of ideas, I thought maybe I should try rewriting my entire model in libtorch in C++. (C++ is fast, right?!) Now that I’ve done that, I’m running into weird issues with my backward pass being excruciatingly slow (around 30 seconds for one backward pass). I saw a page about that here, but didn’t find it very applicable or helpful to my case.

People here told me before that C++ wouldn’t speed up a model in general (because both python and C++ call into the same backend functions), but I still believed it should speed up my code since the slowest part of my python model (by far) was a few dynamic nested python for-loops that I couldn’t figure out how to vectorize into a small number of nice PyTorch operations.

I’m sorry that this question is super long, but it’s hard to explain where my difficulty lies…I don’t know what else to try to make my model better. Do you have any ideas about how to make the below code more pythonic? (PyTorch-ic?). More generally do you have any tricks about how to vectorize complicated logic? Any help would be greatly appreciated.

More context about the model:

The goal of the model is to take in a graph and embed it as a fixed length vector.

The model involves iteratively updating the representation of each node in a tree via a function of itself and its children. I’ve represented that “function of itself and its children” as a submodule which is fairly simple (a concatenation and linear layer or two).

After updating every node, we repeat this process for the length of the tree. (The intuition would be to pass information up from the leaves to the root.) Finally, we choose to represent the entire clause as the vector that has settled at the root of this tree.


import torch
import torch.nn as nn
from typing import List

class TreeEmbedding(nn.Module):
    """Turns a clause/context tree into a vector."""
    
    def __init__(self, signatureSize, symbolDim):
        super().__init__()
        self.symbolVectors = torch.nn.Parameter(torch.randn(signatureSize, symbolDim))

        self.nodeEmbedding = NodeEmbedding(symbolDim, symbolDim)
        self.neighborhoodEmbedding = NeighborhoodEmbedding()

    def forward(self, tree: List[List[int]], 
                      treeSymbols: List[int]):

        nodeNumbers = [v[0] for v in tree]
        nodeChildren = [v[1:] for v in tree]

        root = nodeNumbers[0]

        nodeVecs = [self.symbolVectors[num] for num in [treeSymbols[i] for i in nodeNumbers]]
        nodeVecs = torch.stack(nodeVecs)

        for _ in range(getTreeDepth(tree)):  # loop variable unused; avoids shadowing the inner i
            fs = self.nodeEmbedding(nodeVecs)
            newNodeVecs = []
            for i in range(len(nodeNumbers)):
                children = []
                for j in nodeChildren[i]:
                    children.append(fs[j])
                    
                newNodeVecs.append(self.neighborhoodEmbedding(fs[i], children))
            nodeVecs = torch.stack(newNodeVecs)

        out = nodeVecs[root]
        return out


Hi,

For the python code:
One comment I would have is that the children discovery could be done only once, outside of the tree depth loop, and then reused.
Also I am not sure how you know at the end that nodeVecs[root] actually contains what you want. Because the content of nodeVecs changes at every iteration of the loop. Or maybe I’m missing something here.

For c++ vs python:
I think it will greatly depend on the size of your embeddings, connectivity of your graph etc.
With big enough embeddings, the compute will be predominant over other runtime and there is no benefit to be expected from moving to cpp.
For very small embeddings, it might help.

Also I am not sure how you know at the end that `nodeVecs[root]` actually contains what you want.

My hope is that information from the entire tree gets bubbled up to the root node over the iterations and that the root would then become a suitable representation of the entire tree. I was hoping that backpropagation would be able to learn this. I could also do max-pooling or averaging over the nodes, but I figured for a tree, this might work better?

I think I see what you mean with the children discovery being moved out. This would make it two doubly nested for loops instead of one triply nested for loop, correct? (I’d need two for loops to do the children discovery since I need to find all children of all nodes.) I’ll try that and see if it gets much better.
Thanks for the suggestion!

More generally, do you have any tips for speeding up code which isn’t easily vectorized into pytorch and which might generate large computational graphs for which backprop might be slow?

The only tip here would be to take the time to vectorize it.
Note that one important step here is to have the proper representation of your data: you can do very expensive preprocessing of your data once to put it in the shape you want if that makes the forward pass faster.

For example, you could represent your graph with an adjacency matrix so that accumulating the features from all the neighbors can be done using a single matrix-matrix multiplication.
Or store indices in a linearized way to be able to gather all the neighbors in a single gather.
But this will depend a lot on the constraints given by your aggregation function.
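As a rough illustration of the "linearized indices" idea, here is a minimal sketch. It assumes the aggregation is a plain average of each node with its children and that `tree[i]` lists the children of node `i`. The helper names `build_edge_index` and `mean_with_children` are made up for this example. The ragged child lists are flattened once into two aligned index tensors, after which every update is one gather plus one `index_add`:

```python
import torch
from typing import List

def build_edge_index(tree: List[List[int]]):
    """Flatten ragged child lists into aligned 1-D index tensors (done once per tree)."""
    parents = [p for p, kids in enumerate(tree) for _ in kids]
    children = [c for kids in tree for c in kids]
    parent_idx = torch.tensor(parents, dtype=torch.long)
    child_idx = torch.tensor(children, dtype=torch.long)
    # Number of vectors averaged at each node: the node itself plus its children.
    counts = torch.tensor([1 + len(kids) for kids in tree], dtype=torch.float32)
    return parent_idx, child_idx, counts

def mean_with_children(fs, parent_idx, child_idx, counts):
    # Gather all child vectors at once, then add each onto its parent's row.
    total = fs.index_add(0, parent_idx, fs[child_idx])
    return total / counts.unsqueeze(1)

tree = [[1, 2], [3], [], []]          # node 0 has children 1 and 2; node 1 has child 3
fs = torch.arange(8, dtype=torch.float32).reshape(4, 2)
parent_idx, child_idx, counts = build_edge_index(tree)
out = mean_with_children(fs, parent_idx, child_idx, counts)
```

The index tensors depend only on the tree structure, so they can be computed once and reused on every iteration of the depth loop. Whether this fits depends on the aggregation: it works for sums and means, but not directly for order-sensitive functions of the children.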


I thought some more about your suggestion of moving the children discovery out and I now don’t think it’s possible.

I’ve rewritten the code to hopefully be a bit more readable:

for i in range(getTreeDepth(tree)):
    fs = self.nodeEmbedding(nodeVecs)
    newNodeVecs = []
    for j in range(len(treeSymbols)):
        children = [fs[k] for k in nodeChildren[j]]                   
        newNodeVecs.append(self.neighborhoodEmbedding(fs[j], children))
    nodeVecs = torch.stack(newNodeVecs)

I don’t think I can move the “children discovery” out because it needs to reference the current vector values at the children. This depends on the outermost loop here. Do you see something I’m missing here?

Thanks,
Jack

I am still confused about the content of nodeVecs.
It seems that you fill it by adding the neighbors of the node in order.
But the initial value does not seem to match this.

Thanks for questioning my code because I think it was wrong…if nothing else it was confusing enough that I couldn’t understand it. I decided to structure my data slightly differently so the code could access it a bit more nicely.

Now it looks like this:



class NeighborhoodEmbedding(nn.Module):
    """Combines node embeddings for a node and its neighbors/children into a single vector."""

    def forward(self, fs: torch.Tensor, tree: List[List[int]]):
        childrenVecs = [fs[t] for t in tree]

        # average each node's representation with its children's representations as vectors.
        vecs = [torch.cat([x.reshape(1,-1),y]).mean(0) for x,y in zip(fs, childrenVecs)]
        return torch.stack(vecs)


class TreeEmbedding(nn.Module):
    """Turns a clause/context tree into a vector."""
    
    def __init__(self, signatureSize, symbolDim):
        super().__init__()
        self.symbolVectors = torch.nn.Parameter(torch.randn(signatureSize, symbolDim))

        self.nodeEmbedding = NodeEmbedding(symbolDim, symbolDim)
        self.neighborhoodEmbedding = NeighborhoodEmbedding()

    def forward(self, tree: List[List[int]], 
                      treeSymbols: List[int]):
        """tree lists the children of each node. 
        So if tree[3] == [6,10], this means that node 3 has two children 
        and that its grandchildren are specified at tree[6] and tree[10]"""

        nodeVecs = self.symbolVectors[treeSymbols]

        for i in range(getTreeDepth(tree)):
            fs = self.nodeEmbedding(nodeVecs)
            nodeVecs = self.neighborhoodEmbedding(fs, tree)

        out = nodeVecs[0]
        return out

I think this works and is a little more efficient than before since that triply nested for-loop is now only doubly nested with a pytorch operation kind of working as a stand-in for the other loop.

I think I could do it in a better way if PyTorch supported indexing with a list of lists of different lengths. For instance, I wish [fs[t] for t in tree] could be simply fs[tree].
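One possible workaround, sketched here under assumptions rather than as anything PyTorch provides out of the box, is to pad the ragged `tree` into a rectangular index tensor plus a mask. Indexing with that tensor then behaves like the wished-for `fs[tree]`. The helper name `pad_children` is hypothetical, and the padding slots simply point at node 0 and are zeroed out by the mask:

```python
import torch
from typing import List

def pad_children(tree: List[List[int]]):
    """Turn ragged child lists into an (n, max_children) index tensor and a 0/1 mask."""
    n = len(tree)
    width = max(1, max(len(kids) for kids in tree))
    idx = torch.zeros(n, width, dtype=torch.long)   # padding slots index node 0
    mask = torch.zeros(n, width)                    # 1.0 where a real child exists
    for node, kids in enumerate(tree):
        if kids:
            idx[node, :len(kids)] = torch.tensor(kids, dtype=torch.long)
            mask[node, :len(kids)] = 1.0
    return idx, mask

tree = [[1, 2], [3], [], []]
idx, mask = pad_children(tree)
fs = torch.randn(4, 8)

child_vecs = fs[idx]                                        # shape (4, 2, 8)
child_sum = (child_vecs * mask.unsqueeze(-1)).sum(dim=1)    # ignore padded slots
new_vecs = (fs + child_sum) / (1.0 + mask.sum(dim=1, keepdim=True))
```

The padding wastes some memory when a few nodes have many more children than the rest, so this is a trade-off rather than a free win.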

Can you see any optimizations I can perform to make this better?
Is it now clearer what computation I’m performing?

The self.nodeEmbedding is simply a shallow fully connected NN which doesn’t change the dimension of its input. The self.neighborhoodEmbedding is meant to take the transformed node vectors and the tree (explained in the docstring in the code) and return updated node vectors. Right now I’m implementing this simply as each node’s vector getting mapped to the average of itself and its children.
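For anyone who wants to run the code above, here is a minimal sketch of the two pieces that are not shown. Both are assumptions: the layer sizes and activation in `NodeEmbedding` are illustrative guesses at "a shallow fully connected NN", and `getTreeDepth` is assumed to count the levels of the tree starting from the root at index 0.

```python
import torch
import torch.nn as nn
from typing import List

class NodeEmbedding(nn.Module):
    """Hypothetical stand-in: a shallow MLP applied to every node vector at once."""

    def __init__(self, inDim: int, outDim: int):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(inDim, outDim),
            nn.ReLU(),
            nn.Linear(outDim, outDim),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x has shape (numNodes, inDim); nn.Linear acts on the last dimension.
        return self.net(x)

def getTreeDepth(tree: List[List[int]], root: int = 0) -> int:
    """Assumed definition: number of levels, found by walking the tree level by level."""
    depth, frontier = 0, [root]
    while frontier:
        depth += 1
        frontier = [c for node in frontier for c in tree[node]]
    return depth

tree = [[1, 2], [3], [], []]
depth = getTreeDepth(tree)
fs = NodeEmbedding(8, 8)(torch.randn(len(tree), 8))
```

Because the depth depends only on the tree, it can be computed once per tree and cached instead of being recomputed on each forward pass.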

Thanks again for all your help!

Hi,

This looks clearer indeed.
One proposal I would have would be to make a big 0-1 matrix based on the tree content and use that to perform the neighborhoodEmbedding as a single matrix-matrix multiplication:
The row for each node needs a 1 for the node itself and for each of its neighbors, expanded to the feature dimension size.
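A minimal sketch of this proposal, assuming the averaging `neighborhoodEmbedding` described earlier in the thread and `tree[i]` listing the children of node `i`. The helper name `build_mean_matrix` is made up for illustration. In this version a plain matrix product already applies each row's weights across all feature columns, so no explicit expansion to the feature dimension is needed:

```python
import torch
from typing import List

def build_mean_matrix(tree: List[List[int]]) -> torch.Tensor:
    """Build the 0-1 matrix once per tree, then row-normalize so A @ fs is a mean."""
    n = len(tree)
    A = torch.zeros(n, n)
    for node, kids in enumerate(tree):
        A[node, node] = 1.0          # the node itself
        for k in kids:
            A[node, k] = 1.0         # each of its children
    return A / A.sum(dim=1, keepdim=True)

tree = [[1, 2], [3], [], []]
A = build_mean_matrix(tree)          # shape (numNodes, numNodes)

fs = torch.randn(4, 8, requires_grad=True)
new_vecs = A @ fs                    # every node updated in one matmul
```

The matrix is fixed for a given tree, so the depth loop reduces to alternating `nodeEmbedding` with `A @ fs`. For large trees the matrix is mostly zeros, in which case a sparse tensor may be worth trying.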
