Any efficient way to compute batched minimum spanning trees?


In my research I need to compute minimum spanning trees (MSTs) for a batch of adjacency matrices of shape B x N x N, where each N x N slice holds pairwise distances and B is the batch dimension. I want the output to have the same shape, B x N x N, where each N x N slice is the MST of the corresponding adjacency matrix.

The difficulty is handling the batch. For the non-batched case I plan to call scipy.sparse.csgraph.minimum_spanning_tree, but it does not seem to support a batch dimension, and computing the MSTs one by one would be too slow.
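For reference, the sequential baseline described above might look like the sketch below. The function name `mst_batch_sequential` is made up for illustration, and the sketch assumes symmetric, strictly positive off-diagonal distances, since SciPy's csgraph routines treat zeros in a dense input as missing edges. It is only meant as something to time a batched version against.

```python
import numpy as np
from scipy.sparse.csgraph import minimum_spanning_tree


def mst_batch_sequential(dist):
    """Loop over the batch and call SciPy once per graph.

    dist: array of shape (B, N, N) with symmetric, positive distances.
    Returns an array of shape (B, N, N) holding the MST edge weights.
    """
    dist = np.asarray(dist, dtype=float)
    out = np.zeros_like(dist)
    for b in range(dist.shape[0]):
        # The result is sparse and stores each tree edge in one direction only.
        tree = minimum_spanning_tree(dist[b]).toarray()
        # Mirror the edges so the output is a symmetric adjacency matrix.
        out[b] = np.maximum(tree, tree.T)
    return out
```

Called on a (B, N, N) array, it returns one symmetric MST adjacency matrix per batch element, with zeros wherever an edge is not part of the tree.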


I have the feeling that it would be really nice if torch.sparse supported some sort of MST algorithm on the GPU.


I think the problem is that graph algorithms are quite tricky to implement on the GPU.
Moreover, they are not used much in the context of neural networks, and I'm not aware of any implementation.

Hmm, I see. Maybe I will have to implement this myself. Do you have any suggestions about how I should go about it?

Maybe I will just implement a non-batched version. Since my graphs are pretty small (about 20x20), the speed might still be acceptable.

The thing is, I really want to avoid copying things back and forth between CPU and GPU, which I think defeats the purpose of doing batch learning on GPUs.
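One possible way to stay on the GPU, offered only as a sketch and not as something confirmed in this thread, is to run Prim's algorithm on all B graphs in lockstep using plain tensor indexing. The function name `batched_prim_mst` is hypothetical. The sketch assumes dense, symmetric, floating-point distance matrices describing complete graphs, which matches the small (about 20x20) case mentioned above. It needs N - 1 sequential steps, each vectorized over the batch.

```python
import torch


def batched_prim_mst(dist):
    """Prim's algorithm run in parallel over a batch of dense graphs.

    dist: float tensor of shape (B, N, N), symmetric, complete graphs (assumed).
    Returns a (B, N, N) tensor with MST edge weights and zeros elsewhere.
    """
    B, N, _ = dist.shape
    device = dist.device
    batch = torch.arange(B, device=device)

    # Start every tree from node 0.
    in_tree = torch.zeros(B, N, dtype=torch.bool, device=device)
    in_tree[:, 0] = True
    # best[b, v]: cheapest known edge from the current tree to node v.
    best = dist[:, 0, :].clone()
    # parent[b, v]: the tree node at the other end of that cheapest edge.
    parent = torch.zeros(B, N, dtype=torch.long, device=device)
    mst = torch.zeros_like(dist)

    for _ in range(N - 1):
        # Pick, per graph, the closest node that is not yet in the tree.
        cand = best.masked_fill(in_tree, float("inf"))
        v = cand.argmin(dim=1)
        u = parent[batch, v]
        w = dist[batch, u, v]
        mst[batch, u, v] = w
        mst[batch, v, u] = w
        in_tree[batch, v] = True

        # Relax: the new node may offer cheaper connections to the rest.
        new = dist[batch, v, :]
        closer = new < best
        best = torch.where(closer, new, best)
        parent = torch.where(closer, v.unsqueeze(1), parent)
    return mst
```

If `dist` already lives on the GPU, every intermediate tensor is created on the same device, so no data is copied back to the CPU. Whether this beats a CPU loop for graphs this small would need to be measured.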

Hi Zing, it's very late to reply to this thread, but I wanted to know whether you implemented MST with batched input. If so, it would be very interesting to learn about your approach.