Max-pooling with complex masks in PyTorch

Dear all,

Suppose I have a matrix src with shape (5, 3) and a boolean matrix adj with shape (5, 5), as follows:

src = tensor(
       [[ 0,  1,  2],
        [ 3,  4,  5],
        [ 6,  7,  8],
        [ 9, 10, 11],
        [12, 13, 14]]
)

and

adj = tensor(
       [[1, 0, 1, 1, 0],
        [0, 1, 1, 1, 0],
        [1, 1, 0, 1, 1],
        [1, 1, 1, 0, 0],
        [0, 0, 1, 0, 1]])

We can treat each row of src as one node embedding, and each row of adj as an indicator of which nodes are in that node's neighborhood.

My goal is to max-pool over the embeddings of all neighboring nodes for each node in src.
For example, the neighborhood of the 0-th node (including itself) is nodes 0, 2, 3, so we max-pool over [0, 1, 2], [6, 7, 8], [9, 10, 11], which gives [9, 10, 11] as the updated embedding of the 0-th node in src_update.

A simple solution I wrote is

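import torch

bool_adj = adj.bool()  # treat the 0/1 adjacency matrix as a boolean mask
src_update = torch.zeros_like(src)
for index in range(bool_adj.size(0)):
    # indices of the neighbors of node `index`
    list_of_non_zero = bool_adj[index].nonzero().view(-1)
    mat_non_zero = torch.index_select(src, 0, list_of_non_zero)
    # element-wise max over the neighbor embeddings (max-pooling, not a sum)
    src_update[index] = torch.max(mat_non_zero, dim=0).values
print(src_update)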

which prints src_update as

src_update = tensor([[ 9, 10, 11],
        [ 9, 10, 11],
        [12, 13, 14],
        [ 6,  7,  8],
        [12, 13, 14]])

It works, but the Python loop over nodes runs very slowly and does not look elegant.
Any suggestions for making it more efficient?
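
One direction I have been considering, as a rough sketch rather than a tuned solution, is to replace the loop with broadcasting: build a (5, 5, 3) tensor whose entry [i, j] is the embedding of node j, overwrite the non-neighbors with a small filler value, and take the max over j. The filler here is simply the global minimum of src, which is my own assumption; it relies on every node having at least one neighbor.

```python
import torch

src = torch.arange(15).view(5, 3)
adj = torch.tensor([[1, 0, 1, 1, 0],
                    [0, 1, 1, 1, 0],
                    [1, 1, 0, 1, 1],
                    [1, 1, 1, 0, 0],
                    [0, 0, 1, 0, 1]])

mask = adj.bool().unsqueeze(-1)   # (5, 5, 1): mask[i, j] says whether j is a neighbor of i
candidates = src.unsqueeze(0)     # (1, 5, 3): broadcast over the target node i

# Non-neighbors are replaced by the global minimum of src, so they can never
# exceed a real neighbor (assumes each row of adj has at least one nonzero entry).
masked = torch.where(mask, candidates, src.min())   # (5, 5, 3)

# Max over the neighbor dimension j.
src_update = masked.max(dim=1).values               # (5, 3)
print(src_update)
```

This trades the loop for an intermediate tensor of size N x N x D, so I am not sure it is practical for large graphs.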

In addition, if both src and adj have a leading batch dimension, i.e. shapes (batch, 5, 3) and (batch, 5, 5), how can I make it work?
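
For the batched case, my tentative guess is that a broadcast-and-mask approach carries over if the dimensions are indexed from the end. The helper name masked_neighbor_max below is hypothetical, and the filler value (the global minimum of src) is my own assumption; it requires every node to have at least one neighbor.

```python
import torch

def masked_neighbor_max(src, adj):
    """Hypothetical helper. src: (..., N, D), adj: (..., N, N) -> (..., N, D)."""
    mask = adj.bool().unsqueeze(-1)   # (..., N, N, 1)
    candidates = src.unsqueeze(-3)    # (..., 1, N, D), broadcast over target nodes
    # Non-neighbors get the global minimum so they never win the max
    # (assumes every row of adj has at least one nonzero entry).
    masked = torch.where(mask, candidates, src.min())   # (..., N, N, D)
    return masked.max(dim=-2).values  # max over the neighbor dimension

src = torch.arange(15).view(5, 3)
adj = torch.tensor([[1, 0, 1, 1, 0],
                    [0, 1, 1, 1, 0],
                    [1, 1, 0, 1, 1],
                    [1, 1, 1, 0, 0],
                    [0, 0, 1, 0, 1]])

# Toy batch of two graphs: the second one has all embeddings shifted by 100.
src_batch = torch.stack([src, src + 100])   # (2, 5, 3)
adj_batch = torch.stack([adj, adj])         # (2, 5, 5)

out = masked_neighbor_max(src_batch, adj_batch)   # (2, 5, 3)
print(out)
```

The same function should also accept the unbatched (5, 3) and (5, 5) inputs, since it only indexes dimensions from the end.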

Thanks a lot!!!