Graph Depth First Search

Dear all,

I’m working on a depth-first search algorithm using PyTorch Geometric Data objects. I have a minimum spanning tree (MST) stored as a (2, E) edge-index tensor and a graph object that holds one normal vector per node. While traversing the MST, a node's normal is flipped whenever its dot product with its parent's normal is negative (this is part of an algorithm that constructs normal vectors for a point cloud). I don’t like that the algorithm relies on a Python loop plus a recursive call, because this is slow in Python. Is there a way to speed this up? Code below, and thank you for reading!

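# Imports assumed from the aliases used below.
from torch import Tensor as torch_Tensor, argmax as torch_argmax, zeros as torch_zeros
from torch_geometric.data import Data as tg_data_Data

    @classmethod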
    def flipNormals(cls, graph: tg_data_Data, mst_edge_index: torch_Tensor) -> None:
        r"""
        Flips normals in the graph based on the MST.

        Args:
            graph (Torch Geometric Data): Graph that contains normals that need to be flipped.
            mst_edge_index (Torch Tensor (2, E)): Edges of the Minimal Spanning Tree as (source, destination) node index pairs.
        """

        _n = graph.n
        def dfs_from_node(src, visited):
            visited[src] = True

            # Get neighbors of the current node (boolean mask over all E edges, so O(E) per call)
            neighbors = mst_edge_index[1, mst_edge_index[0] == src]

            # Recursive DFS on unvisited neighbors
            for dest in neighbors:
                if not visited[dest]:
                    if (_n[src] * _n[dest]).sum(dim=0) < 0:
                        _n[dest] *= -1
                    dfs_from_node(dest, visited)

        N = graph.num_nodes
        visited = torch_zeros(N, dtype=bool)
        start_node = torch_argmax(graph.pos[:, 2])
        if _n[start_node, 2] < 0:
            _n[start_node] *= -1
        dfs_from_node(start_node, visited)
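
A possible direction, offered as a minimal and unbenchmarked sketch rather than a definitive answer: compute all per-edge dot products once with tensor operations, then run the traversal on plain Python lists with an explicit stack. This avoids the O(E) mask per node, the per-element tensor indexing, and Python's recursion limit (1000 frames by default), which a deep tree can exceed. The function name `propagate_signs` and its parameters are hypothetical. It assumes the edges really form a tree and, like the code above, only follows an edge from its row-0 node to its row-1 node.

```python
from typing import List, Sequence, Tuple


def propagate_signs(
    num_nodes: int,
    edges: Sequence[Tuple[int, int]],
    edge_dots: Sequence[float],
    start: int,
    start_sign: int = 1,
) -> List[int]:
    """Return +1 or -1 per node; -1 means "flip this node's normal".

    edge_dots[i] is the dot product of the ORIGINAL (unflipped) normals
    at the two endpoints of edges[i].
    """
    # Build the adjacency list once: O(E) in total instead of O(E) per node.
    adjacency: List[List[Tuple[int, float]]] = [[] for _ in range(num_nodes)]
    for (src, dest), dot in zip(edges, edge_dots):
        adjacency[src].append((dest, dot))

    signs = [1] * num_nodes  # unreached nodes keep their normal
    visited = [False] * num_nodes
    signs[start] = start_sign
    visited[start] = True

    stack = [start]  # explicit stack replaces the recursive call
    while stack:
        src = stack.pop()
        for dest, dot in adjacency[src]:
            if visited[dest]:
                continue
            visited[dest] = True
            # The parent may already be flipped, so apply its sign to the
            # original dot product before testing it.
            signs[dest] = -1 if signs[src] * dot < 0 else 1
            stack.append(dest)
    return signs
```

With the tensors from the code above, the inputs could be prepared as `edges = mst_edge_index.t().tolist()` and `edge_dots = (graph.n[mst_edge_index[0]] * graph.n[mst_edge_index[1]]).sum(dim=1).tolist()`, with `start_sign` set to -1 when `graph.n[start_node, 2] < 0`. The result can then be applied in one step, for example `graph.n *= torch.tensor(signs, dtype=graph.n.dtype, device=graph.n.device).unsqueeze(1)`. Because every node in a tree has exactly one parent, the visiting order should not change the outcome, but that no longer holds if the edge set contains cycles.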