PyTorch Geometric: Replicating complex message-passing model from paper

Hi there,

I’m relatively new to PyTorch Geometric (I’ve coded up one GNN so far, though I do have some experience with PyTorch itself). For some research I’m doing, I want to implement the message-passing scheme described on page 4 of this paper, which appears to be more complex than the examples I’ve seen.

The data is in a tree format, and the message-passing step treats each neighbour of a node differently depending on whether that neighbour is the node’s child or its parent: the neighbour’s embedding is passed through one of two MLPs (see step 2 in the referenced paper). I’m not sure how to make my MPNN recognize whether a neighbour is a parent or a child. Perhaps this could be encoded as a feature on the edge between them?
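One way to act on the edge-feature idea, sketched under assumptions: suppose the tree is stored as a hypothetical `parent` tensor where `parent[i]` is the index of node i’s parent (-1 for the root). Each tree edge can then be added in both directions, with an integer `edge_type` recording which direction it is. The function name `tree_to_typed_edges` and the 0/1 convention are my own choices, not something from the paper. With PyG’s default `flow="source_to_target"`, a message travels from `edge_index[0]` to `edge_index[1]`.

```python
import torch


def tree_to_typed_edges(parent: torch.Tensor):
    """Hypothetical helper: parent[i] is node i's parent index, or -1 for the root.

    Returns a bidirectional edge_index plus an edge_type tensor where
      0 = source is a child of the target (child -> parent),
      1 = source is the parent of the target (parent -> child).
    """
    child = torch.nonzero(parent >= 0, as_tuple=True)[0]  # every node that has a parent
    par = parent[child]                                    # its parent, aligned with `child`

    up = torch.stack([child, par], dim=0)    # child -> parent edges (type 0)
    down = torch.stack([par, child], dim=0)  # parent -> child edges (type 1)

    edge_index = torch.cat([up, down], dim=1)
    edge_type = torch.cat([
        torch.zeros(up.size(1), dtype=torch.long),
        torch.ones(down.size(1), dtype=torch.long),
    ])
    return edge_index, edge_type


# Tree: 0 is the root, 1 and 2 are children of 0, 3 is a child of 1.
edge_index, edge_type = tree_to_typed_edges(torch.tensor([-1, 0, 0, 1]))
# edge_index -> [[1, 2, 3, 0, 0, 1], [0, 0, 1, 1, 2, 3]]
# edge_type  -> [0, 0, 0, 1, 1, 1]
```

PyG’s `Data` object accepts extra attributes, so `edge_type` can be stored next to `edge_index`, e.g. `Data(x=x, edge_index=edge_index, edge_type=edge_type)`, and it will be available to the layer at training time.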

The second issue is that a node’s parent and child embeddings are aggregated separately: the mean of each group is taken, and both means are concatenated with the node’s own embedding (step 3 of the paper). As far as I can see, PyG only supports a simpler aggregation in which all of a node’s neighbours are combined at once.

If anyone has any thoughts on how these might be approached, I’d be very interested to hear them :slight_smile: