# Message passing over node embeddings of a graph

I have a matrix of embeddings of size `N*d` (N inputs each having an embedding of size d) and I have a graph adjacency matrix of size `N*N`. Assuming the adjacency matrix is as follows (N=4):

```
[[1, 1, 0, 0],
 [0, 0, 1, 0],
 [1, 0, 0, 0],
 [1, 0, 1, 0]]
```

And the embedding matrix is as follows (d=3):

```
[[1, 2, 3],
 [4, 5, 6],
 [7, 8, 9],
 [10, 11, 12]]
```

I want to get the following output:

```
[[18, 21, 24],
 [1, 2, 3],
 [14, 16, 18],
 [0, 0, 0]]
```

The output is computed as follows: since the first column of the adjacency matrix is 1 at rows 0, 2, and 3, we sum the embeddings at rows 0, 2, 3 and that becomes the first row of the output ([18, 21, 24]). Similarly, since the second column of the adjacency matrix is 1 only at row 0, the second row of the output is just equivalent to row 0 of the embedding. Since the third column of the adjacency matrix is 1 at rows 1 and 3, the third row of the output is the sum of the rows 1 and 3 of the embedding. And since the last column of the adjacency matrix is all zeros, the last row of the output is all zeros.

Can't figure out how to do it without using a for loop. Thanks for any help in advance.
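For concreteness, here is the example as PyTorch tensors, together with the loop-based version I am trying to get rid of (the names `adj` and `emb` are just my own):

```python
import torch

adj = torch.tensor([[1, 1, 0, 0],
                    [0, 0, 1, 0],
                    [1, 0, 0, 0],
                    [1, 0, 1, 0]])
emb = torch.tensor([[1, 2, 3],
                    [4, 5, 6],
                    [7, 8, 9],
                    [10, 11, 12]])

# Naive version with an explicit loop: row `col` of the output is the
# sum of the embedding rows where column `col` of `adj` is non-zero.
out = torch.zeros_like(emb)
for col in range(adj.size(1)):
    out[col] = emb[adj[:, col].bool()].sum(0)
```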

I can't come up with a better solution than this one:

```
adj_nz = adj.nonzero()
```

Since the code might be a bit complicated, here are some explanations:

• `adj_nz` gives a tensor containing the indices of all non-zero values, one `(row, column)` pair per line:
```
tensor([[0, 0],
        [0, 1],
        [1, 2],
        [2, 0],
        [3, 0],
        [3, 2]])
```
• `adj_nz[:, 1] == u` gives a boolean mask over the rows of `adj_nz` for the current unique column value `u`. For `u=0`:
```
tensor([1, 0, 0, 1, 1, 0], dtype=torch.uint8)
```
• We use this mask to index `adj_nz` again and pick out the corresponding non-zero row indices with `adj_nz[adj_nz[:, 1] == u, 0]`:
```
tensor([0, 2, 3])
```
• Once this is done, we can use these row indices to index `emb` and sum along dim 0 with `emb[adj_nz[adj_nz[:, 1] == u, 0]].sum(0)`:
```
tensor([18, 21, 24])
```
• This has to be repeated for every unique value in `adj_nz`, which involves a list comprehension and a call to `torch.stack`. I'm not very happy about this last part, but am currently unsure how to avoid it or whether it would be a bottleneck at all.
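Putting the steps above together into a runnable sketch (tensor names as in the bullets; note that I iterate over all column indices rather than only the unique values in `adj_nz[:, 1]`, so that an all-zero column like the last one still produces a zero row):

```python
import torch

adj = torch.tensor([[1, 1, 0, 0],
                    [0, 0, 1, 0],
                    [1, 0, 0, 0],
                    [1, 0, 1, 0]])
emb = torch.tensor([[1, 2, 3],
                    [4, 5, 6],
                    [7, 8, 9],
                    [10, 11, 12]])

# (row, column) indices of all non-zero entries of the adjacency matrix
adj_nz = adj.nonzero()
# For each column u, gather the embedding rows where column u is non-zero
# and sum them; an empty selection sums to a zero row.
out = torch.stack([emb[adj_nz[adj_nz[:, 1] == u, 0]].sum(0)
                   for u in range(adj.size(1))])
```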

`torch.einsum("ik,ij->kj", (adj, embed))`, where `adj` is the given adjacency matrix of size `N * N` and `embed` is the embedding matrix of size `N * d`.
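To make the one-liner concrete, here is a quick check on the tensors from the question (the index string reads: `out[k, j] = sum_i adj[i, k] * embed[i, j]`):

```python
import torch

adj = torch.tensor([[1, 1, 0, 0],
                    [0, 0, 1, 0],
                    [1, 0, 0, 0],
                    [1, 0, 1, 0]])
embed = torch.tensor([[1, 2, 3],
                      [4, 5, 6],
                      [7, 8, 9],
                      [10, 11, 12]])

# "ik,ij->kj": sum over the shared row index i, so row k of the output
# aggregates the embedding rows where column k of adj is non-zero.
out = torch.einsum("ik,ij->kj", adj, embed)
```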


Oh wow, that's beautiful! And just a single line of code.
@Mehran forget about my approach. Use this one, as itâ€™s cleaner and faster!

Apparently I haven't played enough with `torch.einsum`.
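If it helps demystify the einsum: with these index letters it is just a transposed matrix product, so the same result also comes from a plain `@` (a sketch on the question's data):

```python
import torch

adj = torch.tensor([[1, 1, 0, 0],
                    [0, 0, 1, 0],
                    [1, 0, 0, 0],
                    [1, 0, 1, 0]])
emb = torch.tensor([[1, 2, 3],
                    [4, 5, 6],
                    [7, 8, 9],
                    [10, 11, 12]])

# "ik,ij->kj" sums over i, the shared row index: out = adj^T @ emb
out = adj.t() @ emb
```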
