Message passing over node embeddings of a graph

I have an N*d matrix of embeddings (N inputs, each with an embedding of size d) and an N*N graph adjacency matrix. Suppose the adjacency matrix is as follows (N=4):

[[1, 1, 0, 0],
[0, 0, 1, 0],
[1, 0, 0, 0],
[1, 0, 1, 0]]

And the embedding matrix is as follows (d=3):

[[1, 2, 3],
[4, 5, 6],
[7, 8, 9],
[10, 11, 12]]

I want to get the following output:

[[18, 21, 24],
[1, 2, 3],
[14, 16, 18],
[0, 0, 0]]

The output is computed as follows: since the first column of the adjacency matrix is 1 at rows 0, 2, and 3, we sum the embeddings at rows 0, 2, and 3, and that becomes the first row of the output ([18, 21, 24]). Similarly, since the second column of the adjacency matrix is 1 only at row 0, the second row of the output is just row 0 of the embedding matrix. Since the third column of the adjacency matrix is 1 at rows 1 and 3, the third row of the output is the sum of rows 1 and 3 of the embedding matrix. And since the last column of the adjacency matrix is all zeros, the last row of the output is all zeros. In general, output row j is the sum of the embedding rows i for which adj[i, j] = 1.
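Written as a formula (a restatement of the rule above, with A the adjacency matrix, E the embedding matrix and E_i its i-th row, rows indexed from 0), output row j is a column-weighted sum of embedding rows, which is exactly row j of the product of the transposed adjacency matrix with E:

```latex
\mathrm{out}_{j} \;=\; \sum_{i=0}^{N-1} A_{ij}\, E_{i} \;=\; \left(A^{\top} E\right)_{j},
\qquad A \in \{0,1\}^{N \times N},\; E \in \mathbb{R}^{N \times d},\; \mathrm{out} \in \mathbb{R}^{N \times d}
```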

I can’t figure out how to do this without a for loop. Thanks in advance for any help.

I can’t come up with a better solution than this one:

import torch

adj_nz = adj.nonzero()  # (row, col) indices of the non-zero entries of adj
torch.stack([emb[adj_nz[adj_nz[:, 1] == u, 0]].sum(0) for u in torch.unique(adj_nz, sorted=True)])
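For reference, here is a self-contained sketch of the loop-based approach above, run on the example matrices from the question (adj and emb are assumed to hold those values). One deliberate change: it iterates over range(adj.size(1)) rather than over the unique values in adj_nz, so every node gets exactly one output row even if its index never appears among the non-zero entries.

```python
import torch

# Example inputs from the question (N = 4, d = 3).
adj = torch.tensor([[1, 1, 0, 0],
                    [0, 0, 1, 0],
                    [1, 0, 0, 0],
                    [1, 0, 1, 0]])
emb = torch.tensor([[1, 2, 3],
                    [4, 5, 6],
                    [7, 8, 9],
                    [10, 11, 12]])

# (row, col) index pairs of the non-zero entries of adj.
adj_nz = adj.nonzero()

# For each output row u, select the embedding rows i with adj[i, u] != 0 and sum them.
# A column with no non-zero entries selects an empty (0, d) slice, whose sum is a zero row.
out = torch.stack([emb[adj_nz[adj_nz[:, 1] == u, 0]].sum(0) for u in range(adj.size(1))])
print(out)
```

With the original torch.unique version the result is the same for this particular example, because every index 0..3 happens to occur somewhere in adj_nz.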

Since the code might be a bit complicated, here are some explanations:

  • adj_nz is a tensor containing the indices of all non-zero entries of adj, one (row, column) pair per entry:
tensor([[0, 0],
        [0, 1],
        [1, 2],
        [2, 0],
        [3, 0],
        [3, 2]])
  • adj_nz[:, 1] == u gives a mask selecting the non-zero entries whose column index equals the current unique value u. For u=0:
tensor([1, 0, 0, 1, 1, 0], dtype=torch.uint8)
  • We use this mask to index adj_nz again and get the corresponding row indices via adj_nz[adj_nz[:, 1] == u, 0]:
tensor([0, 2, 3])
  • Once this is done, we use this result to index emb and sum along dim 0 with emb[adj_nz[adj_nz[:, 1] == u, 0]].sum(0):
tensor([18, 21, 24])
  • This has to be repeated for every unique value in adj_nz, which requires a list comprehension and a call to torch.stack. I’m not very happy about this, but I’m currently unsure how to avoid it, or whether it is a bottleneck at all.
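One way to drop the list comprehension while keeping the non-zero-index idea is to gather the source embedding for every non-zero entry and scatter-add it into the row given by the column index. This swaps in Tensor.index_add_, which is not part of the approach described above; a minimal sketch, assuming the same example adj and emb:

```python
import torch

adj = torch.tensor([[1, 1, 0, 0],
                    [0, 0, 1, 0],
                    [1, 0, 0, 0],
                    [1, 0, 1, 0]])
emb = torch.tensor([[1, 2, 3],
                    [4, 5, 6],
                    [7, 8, 9],
                    [10, 11, 12]])

# Each row of adj_nz is one non-zero entry (src, dst) with adj[src, dst] != 0.
adj_nz = adj.nonzero()
src, dst = adj_nz[:, 0], adj_nz[:, 1]

# Start from zeros so columns with no non-zero entries keep an all-zero output row,
# then add emb[src] into output row dst for every entry in a single call.
out = torch.zeros(adj.size(1), emb.size(1), dtype=emb.dtype)
out.index_add_(0, dst, emb[src])
print(out)
```

This treats every non-zero entry as a weight of 1; for a weighted adjacency matrix, multiply emb[src] by adj[src, dst].unsqueeze(1) before adding.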

torch.einsum("ik,ij->kj", adj, embed), where adj is the N*N adjacency matrix and embed is the N*d embedding matrix.
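Why this works: in "ik,ij->kj" the index i (the row of adj) appears in both inputs but not in the output, so it is summed over, and output row k collects embed[i] for every i with adj[i, k] = 1. That is the same contraction as adj.t() @ embed. A quick check on the example from the question, assuming both tensors are float32 (einsum expects its operands to share a dtype, so convert adj first if it is stored as bool or int):

```python
import torch

adj = torch.tensor([[1, 1, 0, 0],
                    [0, 0, 1, 0],
                    [1, 0, 0, 0],
                    [1, 0, 1, 0]], dtype=torch.float32)
embed = torch.tensor([[1, 2, 3],
                      [4, 5, 6],
                      [7, 8, 9],
                      [10, 11, 12]], dtype=torch.float32)

# Sum over i (the row index of adj): out[k, j] = sum_i adj[i, k] * embed[i, j].
out = torch.einsum("ik,ij->kj", adj, embed)

# The same contraction written as a plain matrix product with the transposed adjacency.
out_mm = adj.t() @ embed

print(out)
print(torch.equal(out, out_mm))
```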


Oh wow, that’s beautiful! And just a single line of code.
@Mehran, forget about my approach. Use this one, as it’s cleaner and faster!

Apparently I haven’t played enough with torch.einsum.
