Any efficient way to compute batched minimum spanning tree?

Hi!

In my research I need to compute the minimum spanning trees (MSTs) of a batch of adjacency matrices of size B x N x N, where each N x N matrix holds pairwise distances and B is the batch dimension. I want to output a tensor of size B x N x N, where each N x N slice is the MST computed from the corresponding adjacency matrix.

The problem is dealing with the batch. For the non-batched version I plan on calling scipy.sparse.csgraph.minimum_spanning_tree, but it does not seem to support a batch dimension, and computing the MSTs sequentially would be too slow.
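For reference, the sequential baseline described here could look like the sketch below. `batched_mst_scipy` is a hypothetical name, and the sketch assumes symmetric distance matrices with strictly positive off-diagonal entries, because scipy treats zero entries of a dense input as missing edges. Everything runs on the CPU, so GPU tensors would first have to be copied to host memory.

```python
import numpy as np
from scipy.sparse.csgraph import minimum_spanning_tree


def batched_mst_scipy(D):
    # D: B x N x N array of symmetric distances (zeros are treated as "no edge").
    # Returns a B x N x N array holding the MST edge weights of each graph.
    B, N, _ = D.shape
    out = np.zeros_like(D)
    for b in range(B):  # one scipy call per graph: this is the slow, sequential part
        T = minimum_spanning_tree(D[b]).toarray()
        out[b] = T + T.T  # scipy stores each tree edge only once, so symmetrize
    return out
```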

Thanks!

%% I have the feeling that it would be really nice if torch.sparse could support some sort of MST algorithm on GPU?


Hi,

I think the problem is that graph algorithms are quite tricky to implement efficiently on GPU.
Moreover, they are not used much in the context of neural networks, and I'm not aware of any implementation.

Hmmmmm, I see. Maybe I will have to implement this myself; do you have any suggestions about how I should implement it?

Maybe I will just implement a non-batched version; since my graphs are pretty small (about 20x20), the speed might still be acceptable.

The thing is, I really want to avoid copying data back and forth between the CPU and GPU, which I think defeats the purpose of doing batch learning on GPUs.
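One way to keep everything on the GPU for graphs this small is to run the dense O(N^2) variant of Prim's algorithm on all graphs at once, so that each of the N - 1 steps is a handful of tensor operations over the whole batch. This is only a sketch under assumptions: `batched_prim_mst` is a hypothetical name, the inputs are assumed to be symmetric, floating-point distance matrices of complete graphs, and the output stores MST edge weights (so a zero-weight edge would look like a missing one).

```python
import torch


def batched_prim_mst(D):
    # D: B x N x N float tensor of symmetric distances (complete graphs assumed).
    # Returns a B x N x N tensor with the MST edge weights and 0 elsewhere.
    B, N, _ = D.shape
    device = D.device
    batch = torch.arange(B, device=device)

    in_tree = torch.zeros(B, N, dtype=torch.bool, device=device)
    in_tree[:, 0] = True  # grow every tree from node 0
    # best[b, v]: cheapest known edge from node v to the tree of graph b
    # parent[b, v]: the tree node at the other end of that edge
    best = D[:, 0, :].clone()
    parent = torch.zeros(B, N, dtype=torch.long, device=device)

    mst = torch.zeros_like(D)
    for _ in range(N - 1):
        # Pick the closest non-tree node in every graph simultaneously.
        v = best.masked_fill(in_tree, float("inf")).argmin(dim=1)
        u = parent[batch, v]
        w = D[batch, u, v]
        mst[batch, u, v] = w
        mst[batch, v, u] = w
        in_tree[batch, v] = True
        # Relax: nodes that are closer to v than to the old tree now attach to v.
        d_v = D[batch, v, :]  # B x N
        closer = d_v < best
        best = torch.where(closer, d_v, best)
        parent = torch.where(closer, v.unsqueeze(1), parent)
    return mst
```

Each iteration touches O(B·N) entries, so 20-node graphs need only 19 iterations of batched tensor operations, and no intermediate result is moved to the host.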

Hi Zing, I know it's very late to reply to this thread, but did you end up implementing MST with batched input? If so, it would be very interesting to learn your approach.

I recently ran into the same problem, and here’s my solution combining Prim’s algorithm & array operations. Perhaps this could be useful for someone.

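The following function assumes the input log-weight matrix, logQ, has dimensions S x nV x nV, with S the batch size and nV the number of nodes. Note that it builds maximum spanning trees; to get minimum spanning trees from a batch of distance matrices D, pass logQ = -D.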

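```python
import torch


def computeMaxSpanningTrees(logQ):
    # logQ: S x nV x nV tensor of symmetric log edge weights.
    # Returns an S x nV x nV 0/1 adjacency matrix holding the maximum spanning
    # tree of each graph, grown with Prim's algorithm in parallel over the batch.
    S, nV = logQ.shape[0], logQ.shape[1]
    device = logQ.device

    # Flat indices of the strictly upper-triangular entries (each undirected
    # edge appears once). Entry (0, 0) maps to 0 but lies on the diagonal anyway.
    grid = torch.arange(nV * nV, device=device).reshape(nV, nV).triu(1)
    triuBool = grid > 0
    triuIdx = grid[triuBool]

    seqS = torch.arange(S, device=device)

    def extractTriuAndFlatten(A):
        return A.reshape(S, nV * nV)[:, triuIdx]

    # exp() is monotonic, so comparing log-weights selects the same edges as
    # comparing weights, without exp() underflowing to 0 and tying with masked edges.
    logQ_flat = extractTriuAndFlatten(logQ)

    # Grow every tree from the last node.
    inTree = torch.zeros((S, nV, 1), dtype=torch.bool, device=device)
    inTree[:, nV - 1] = True

    A_trees = torch.zeros_like(logQ)

    for step in range(nV - 1):
        # Candidate edges have exactly one endpoint inside the current tree.
        validEdge = (inTree != inTree.transpose(1, 2)) & triuBool
        # Mask invalid edges with -inf so argmax can never select them.
        scores = logQ_flat.masked_fill(~extractTriuAndFlatten(validEdge), float("-inf"))
        drawnEdgeIdx = triuIdx[scores.argmax(1)]

        drawn_i = drawnEdgeIdx // nV  # integer division, not float division
        drawn_j = drawnEdgeIdx % nV

        inTree[seqS, drawn_i] = True
        inTree[seqS, drawn_j] = True
        A_trees[seqS, drawn_i, drawn_j] = 1
        A_trees[seqS, drawn_j, drawn_i] = 1

    return A_trees
```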