Dijkstra's algorithm in PyTorch

I am working on 3D point clouds. I have a sparse matrix representation of the graph structure of the point cloud (like `csr_matrix` in `scipy.sparse`). I want to group together the points that lie within a certain threshold of geodesic distance (approximated by the shortest path length in the graph) and process them together. To find such points, I need to run a shortest-path algorithm such as Dijkstra's. In a nutshell, my idea is this:

1) Sample K points out of N points (I could do this with Furthest Point Sampling)
2) Find the nearest geodesic neighbours of each of the K points (using an algorithm that supports backprop)
3) Process the neighbours of each point with a neural network
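Step 1 can be sketched in plain PyTorch. This is a minimal, non-differentiable farthest point sampling that assumes Euclidean distances on an `(N, 3)` coordinate tensor; the function name `farthest_point_sampling` is hypothetical, not a library API. Choosing indices is a discrete operation, so it does not need gradients itself; the features later gathered at those indices still do.

```python
import torch


def farthest_point_sampling(xyz: torch.Tensor, k: int, start: int = 0) -> torch.Tensor:
    """Pick k indices from an (N, 3) point cloud, each time taking the point
    farthest (Euclidean distance) from all points selected so far."""
    n = xyz.shape[0]
    selected = torch.empty(k, dtype=torch.long)
    # min_dist[p] = distance from point p to its nearest already-selected point
    min_dist = torch.full((n,), float("inf"), dtype=xyz.dtype)
    current = start
    for s in range(k):
        selected[s] = current
        d = torch.linalg.norm(xyz - xyz[current], dim=1)
        min_dist = torch.minimum(min_dist, d)
        # The next sample is the point farthest from the current selection.
        current = int(torch.argmax(min_dist))
    return selected
```

For example, on the three collinear points `(0,0,0)`, `(1,0,0)` and `(10,0,0)` starting from index 0, it picks the far end first and returns `[0, 2, 1]`.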

This will go in my `forward` function.
Is there a way to implement Dijkstra's algorithm (with backprop support) for this?

Or is there another approach I could implement?

Thank you!

I created my own implementation of Dijkstra's algorithm using a priority queue, as discussed here.
For that, I wrote a custom priority queue class using torch functions:

```python
import numpy as np
import torch


class priorityQ_torch(object):
    """Priority queue implementation in PyTorch.

    Each row of ``self.q`` is an ``[index, weight]`` pair. Rows are kept in
    ascending order of weight, so the top (minimum-weight) element is ``self.q[0]``.
    """

    def __init__(self, val):
        # Start with node `val` at distance 0. Use floats so every row has
        # the same dtype as the [index, distance] rows pushed later.
        self.q = torch.tensor([[float(val), 0.0]])

    def push(self, x):
        """Pushes x into q based on the weight value in x, keeping ascending order.

        Args:
            x (torch.Tensor or np.ndarray): [index, weight] pair to be inserted.
        """
        if isinstance(x, np.ndarray):
            x = torch.from_numpy(x)
        x = x.to(self.q.dtype).reshape(1, 2)
        if self.isEmpty():
            self.q = x
            return  # nothing to sort against; avoid inserting x twice
        # First position whose weight is >= the new weight.
        idx = torch.searchsorted(self.q[:, 1].contiguous(), x[:, 1]).item()
        self.q = torch.cat([self.q[:idx], x, self.q[idx:]], dim=0)

    def top(self):
        """Returns the top element: the [index, weight] row with the minimum weight."""
        return self.q[0]

    def pop(self):
        """Pops (without returning) the highest-priority element, i.e. the minimum weight."""
        if self.isEmpty():
            raise IndexError("Can not pop from an empty queue")
        self.q = self.q[1:]

    def isEmpty(self):
        """Returns True if the priority queue is empty."""
        return self.q.shape[0] == 0
```
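Because every push rebuilds the whole tensor with `torch.cat`, each insertion costs time proportional to the queue length. The queue only stores indices and plain (detached) distances, so it does not have to be a tensor at all. A minimal alternative sketch using only the standard library's `heapq` (the class name `HeapPriorityQ` is hypothetical, and its `push` takes the index and weight as two arguments rather than one pair):

```python
import heapq


class HeapPriorityQ:
    """Min-priority queue of (weight, index) entries backed by heapq.

    push and pop are O(log n) instead of copying the whole queue.
    """

    def __init__(self, val):
        # Start with node `val` at distance 0, stored as (weight, index)
        # so that heapq orders entries by weight.
        self.heap = [(0.0, val)]

    def push(self, index, weight):
        heapq.heappush(self.heap, (float(weight), int(index)))

    def top(self):
        # Return in the (index, weight) order used by priorityQ_torch.
        weight, index = self.heap[0]
        return index, weight

    def pop(self):
        if self.isEmpty():
            raise IndexError("Can not pop from an empty queue")
        heapq.heappop(self.heap)

    def isEmpty(self):
        return len(self.heap) == 0
```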

Now Dijkstra itself, taking the adjacency matrix (with the edge weights) as input:

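```python
def dijkstra(adj):
    """All-pairs shortest path lengths from a dense weighted adjacency matrix.

    adj[v, j] is the weight of the edge v -> j; 0 means there is no edge.
    Unreachable nodes keep distance inf.
    """
    n = adj.shape[0]
    distance_matrix = torch.zeros([n, n], dtype=adj.dtype)
    for i in range(n):
        d = torch.full((n,), float("inf"), dtype=adj.dtype)
        d[i] = 0
        q = priorityQ_torch(i)
        while not q.isEmpty():
            v, d_v = q.top()  # point and distance
            q.pop()
            v = int(v)
            if d_v > d[v]:
                continue  # stale entry: v was already reached by a shorter path
            for j, py in enumerate(adj[v]):
                if py != 0 and j != v:  # follow only existing edges
                    if d[v] + py < d[j]:
                        d[j] = d[v] + py
                        # The queue only needs the value, not the autograd graph.
                        q.push(torch.tensor([float(j), d[j].item()]))
        distance_matrix[i] = d
    return distance_matrix
```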
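For step 2 of the question, only the neighbours within the geodesic threshold are needed, not the full n×n matrix, and the graph is already stored sparsely. A minimal single-source sketch under these assumptions: the graph is given as CSR arrays (`indptr`, `indices`, as on a `scipy.sparse.csr_matrix`), edge weights are a non-negative 1-D torch tensor aligned with `indices`, and the function name `geodesic_neighbours` is hypothetical. Ordering uses plain floats in a `heapq`, while the returned distances are sums of entries of `weights`, so gradients reach the weights on the chosen shortest paths (the choice of path itself is not differentiable).

```python
import heapq

import torch


def geodesic_neighbours(indptr, indices, weights, source, radius):
    """Single-source Dijkstra over a CSR graph that stops expanding at `radius`.

    indptr, indices: CSR structure (e.g. csr_matrix.indptr / .indices).
    weights: 1-D torch tensor of non-negative edge weights aligned with `indices`.
    Returns {node: distance} for every node within `radius` of `source`;
    each distance is a 0-d tensor built from `weights`, so it stays differentiable.
    """
    dist = {source: torch.zeros((), dtype=weights.dtype)}
    heap = [(0.0, source)]  # (plain float distance, node), used for ordering only
    settled = set()
    while heap:
        _, v = heapq.heappop(heap)
        if v in settled:
            continue  # stale heap entry
        settled.add(v)
        for e in range(int(indptr[v]), int(indptr[v + 1])):
            j = int(indices[e])
            cand = dist[v] + weights[e]  # autograd tracks this sum
            c = cand.item()
            # Nodes beyond `radius` are never pushed, which prunes the search.
            if c <= radius and (j not in dist or c < dist[j].item()):
                dist[j] = cand
                heapq.heappush(heap, (c, j))
    return dist
```

With a `csr_matrix` `A`, a call could look like `geodesic_neighbours(A.indptr, A.indices, torch.from_numpy(A.data), s, r)`.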

It returns the shortest-path distance matrix for the graph points (`inf` where no path exists)!
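Once a distance matrix like this is available, steps 2 and 3 of the question reduce to thresholding its rows for the K sampled centres and pooling features over each neighbourhood. A minimal sketch, assuming a shared MLP followed by max-pooling (one common choice for point sets, similar in spirit to PointNet); the module name `GeodesicGroupPool` and its arguments are hypothetical:

```python
import torch
import torch.nn as nn


class GeodesicGroupPool(nn.Module):
    """Shared MLP on every point, then max-pool over each centre's geodesic ball."""

    def __init__(self, in_dim, out_dim):
        super().__init__()
        self.mlp = nn.Sequential(nn.Linear(in_dim, out_dim), nn.ReLU())

    def forward(self, feats, dist, centres, radius):
        # feats: (N, C) point features; dist: (N, N) shortest-path lengths
        # (inf if unreachable); centres: (K,) sampled indices; radius: threshold.
        mask = dist[centres] <= radius                     # (K, N) neighbour mask
        h = self.mlp(feats)                                # (N, H)
        h = h.unsqueeze(0).expand(mask.shape[0], -1, -1)   # (K, N, H)
        # Non-neighbours get -inf so they never win the max. Each centre is its
        # own neighbour (distance 0), so every row has at least one finite entry.
        h = h.masked_fill(~mask.unsqueeze(-1), float("-inf"))
        return h.max(dim=1).values                         # (K, H)
```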

Hope this helps someone! :slight_smile:

Thank you!