Hi there,
I’m relatively new to PyTorch Geometric (I’ve coded up one GNN so far, though I do have some experience with PyTorch). For some research I’m doing, I want to implement the message-passing scheme described on page 4 of this paper, which appears to be more complex than the examples I’ve seen.
The data is in a tree format, and the message-passing step treats each neighbour of a node differently depending on whether that neighbour is the node’s parent or one of its children: the neighbour’s embedding is passed through one of two MLPs (see step 2 in the referenced paper). I’m not sure how to make my MPNN recognise whether a neighbour is a parent or a child. Perhaps this could be encoded as a feature on the edge between them?
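One way to encode the parent/child distinction, as suggested above, is an integer edge type attached to each directed edge. Below is a minimal sketch in plain PyTorch, assuming the tree is stored as a parent array (`parent[i]` is the index of node i’s parent, `-1` for the root). `tree_edges`, `typed_messages`, `PARENT`/`CHILD` and the MLP sizes are hypothetical names chosen for illustration; they are not part of PyG or the paper. Inside a PyG `MessagePassing` subclass, the same selection could live in `message()`, since `propagate()` forwards extra per-edge tensors such as `edge_type` to it.

```python
import torch
from torch import nn

# Hypothetical edge-type labels describing the SOURCE node's role relative to the target.
PARENT, CHILD = 0, 1


def tree_edges(parent):
    """Build a PyG-style edge_index (row 0 = source j, row 1 = target i) and an
    edge_type tensor from a parent array; parent[i] == -1 marks the root."""
    src, dst, etype = [], [], []
    for node, par in enumerate(parent):
        if par < 0:
            continue  # the root has no parent edge
        # First entry: parent -> child message (the source is the target's parent).
        # Second entry: child -> parent message (the source is the target's child).
        src += [par, node]
        dst += [node, par]
        etype += [PARENT, CHILD]
    return (torch.tensor([src, dst], dtype=torch.long),
            torch.tensor(etype, dtype=torch.long))


def typed_messages(x, edge_index, edge_type, mlp_parent, mlp_child):
    """Send each source embedding x_j through the MLP selected by its edge type.
    Both MLPs run on every edge and torch.where keeps one result per edge:
    simple, but twice the work of masking."""
    x_j = x[edge_index[0]]
    use_parent_mlp = (edge_type == PARENT).unsqueeze(-1)
    return torch.where(use_parent_mlp, mlp_parent(x_j), mlp_child(x_j))


# Hypothetical tree: node 0 is the root, nodes 1 and 2 are its children, node 3 is a child of 1.
parent = [-1, 0, 0, 1]
edge_index, edge_type = tree_edges(parent)
x = torch.randn(4, 8)
mlp_p = nn.Sequential(nn.Linear(8, 16), nn.ReLU(), nn.Linear(16, 16))
mlp_c = nn.Sequential(nn.Linear(8, 16), nn.ReLU(), nn.Linear(16, 16))
msgs = typed_messages(x, edge_index, edge_type, mlp_p, mlp_c)
print(edge_index.tolist())  # [[0, 1, 0, 2, 1, 3], [1, 0, 2, 0, 3, 1]]
print(msgs.shape)           # torch.Size([6, 16])
```

Storing both directions of every tree edge keeps the graph compatible with PyG’s usual source-to-target convention, while the edge type records which of the two MLPs applies.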
The second issue is that a node’s parent and child embeddings are aggregated separately: the mean of each group is taken, and both means are concatenated with the node’s own embedding (step 3 of the paper). As far as I can see, PyG’s built-in aggregation only supports a simpler scheme in which all of a node’s neighbours are considered at once.
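Given an edge-type tensor like the one discussed above, the separate aggregation can be done by masking the edges before averaging. A minimal sketch in plain PyTorch, assuming per-edge messages have already been computed and that edge type 0 means "source is the target’s parent" and 1 means "source is the target’s child"; `scatter_mean` and `split_aggregate` are hypothetical helpers written for illustration, not PyG functions.

```python
import torch


def scatter_mean(src, index, num_nodes):
    """Average the rows of src that share a target index; targets with no rows get zeros."""
    total = src.new_zeros(num_nodes, src.size(-1))
    total.index_add_(0, index, src)
    count = src.new_zeros(num_nodes)
    count.index_add_(0, index, torch.ones_like(index, dtype=src.dtype))
    return total / count.clamp(min=1).unsqueeze(-1)


def split_aggregate(x, messages, edge_index, edge_type):
    """Mean-aggregate parent-sourced (type 0) and child-sourced (type 1) messages
    separately per target node, then concatenate [x_i, parent_mean_i, child_mean_i]."""
    dst, n = edge_index[1], x.size(0)
    from_parent, from_child = edge_type == 0, edge_type == 1
    parent_mean = scatter_mean(messages[from_parent], dst[from_parent], n)
    child_mean = scatter_mean(messages[from_child], dst[from_child], n)
    return torch.cat([x, parent_mean, child_mean], dim=-1)


# Hypothetical tree: node 0 is the root, nodes 1 and 2 are its children, node 3 is a child of 1.
edge_index = torch.tensor([[0, 1, 0, 2, 1, 3],
                           [1, 0, 2, 0, 3, 1]])
edge_type = torch.tensor([0, 1, 0, 1, 0, 1])  # 0: source is the target's parent, 1: its child
x = torch.zeros(4, 2)
messages = torch.arange(12, dtype=torch.float).view(6, 2)  # stand-in for per-edge MLP outputs
h = split_aggregate(x, messages, edge_index, edge_type)
print(h.shape)  # torch.Size([4, 6])
```

Here the root gets a zero parent mean and leaves get a zero child mean; that padding is an assumption, so it is worth checking how the paper treats nodes without a parent or children. Within PyG, a comparable route could be a `MessagePassing` layer with `aggr='mean'` whose `propagate()` is called twice, once on `edge_index[:, edge_type == 0]` and once on `edge_index[:, edge_type == 1]`, with the two outputs concatenated to `x`.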
If anyone has any thoughts on how to approach either of these, I’d be very interested to hear them.