Dear all,
I’m working on a depth-first search algorithm using PyTorch Geometric Data objects. I have a minimum spanning tree (MST) stored as a (2, E) edge-index tensor and a graph object that holds one normal vector per node. While traversing the MST, we flip a node's normal whenever its dot product with the parent's normal is negative (this is part of an algorithm that constructs normal vectors for a point cloud). I don’t like that the algorithm combines a loop with a recursive call, because this is slow in Python. Is there a way I can speed this up? Code below, and thank you for reading!
@classmethod
def flipNormals(cls, graph: tg_data_Data, mst_edge_index: torch_Tensor) -> None:
    r"""
    Flips normals in the graph based on the MST.

    Args:
        graph (Torch Geometric Data): Graph that contains normals that need to be flipped.
        mst_edge_index (Torch Tensor (2, E)): Edges of the Minimal Spanning Tree as (source, destination) node indices.
    """
    _n = graph.n

    def dfs_from_node(src, visited):
        visited[src] = True
        # Get neighbors of the current node
        neighbors = mst_edge_index[1, mst_edge_index[0] == src]
        # Recursive DFS on unvisited neighbors
        for dest in neighbors:
            if not visited[dest]:
                # Flip the neighbor's normal if it points away from the current one
                if (_n[src] * _n[dest]).sum(dim=0) < 0:
                    _n[dest] *= -1
                dfs_from_node(dest, visited)

    N = graph.num_nodes
    visited = torch_zeros(N, dtype=bool)
    # Start at the highest point and make its normal point upwards (+z)
    start_node = torch_argmax(graph.pos[:, 2])
    if _n[start_node, 2] < 0:
        _n[start_node] *= -1
    dfs_from_node(start_node, visited)
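One possible way to reduce the Python overhead is to replace the recursive DFS with a level-by-level breadth-first traversal that processes all edges of a level with tensor operations. Because a tree has exactly one path from the start node to every other node, the final orientation does not depend on whether the tree is walked depth-first or breadth-first. The sketch below is not the original method: the function name `flip_normals_bfs` and its plain-tensor arguments are hypothetical, and it assumes the MST may store each edge in only one direction, so it adds the reverse edges before traversing.

```python
import torch


def flip_normals_bfs(normals, pos, mst_edge_index):
    """Orient normals along a tree, one BFS level per Python iteration.

    normals: (N, 3) float tensor, modified in place (e.g. graph.n).
    pos: (N, 3) float tensor of node positions (e.g. graph.pos).
    mst_edge_index: (2, E) long tensor holding the tree edges.
    """
    # Assumption: edges may be stored in one direction only, so add the
    # reversed copies to make the traversal direction-independent.
    edges = torch.cat([mst_edge_index, mst_edge_index.flip(0)], dim=1)
    src, dst = edges[0], edges[1]

    num_nodes = normals.size(0)
    visited = torch.zeros(num_nodes, dtype=torch.bool, device=normals.device)

    # Same seeding as the recursive version: highest point, normal facing +z.
    start = torch.argmax(pos[:, 2])
    if normals[start, 2] < 0:
        normals[start] *= -1
    visited[start] = True
    frontier = visited.clone()

    while frontier.any():
        # Edges that lead from the current frontier to unvisited nodes.
        active = frontier[src] & ~visited[dst]
        s, d = src[active], dst[active]
        # Each unvisited node has a single parent in a tree. Duplicate
        # edges are harmless: the same flipped value is written twice.
        flip = (normals[s] * normals[d]).sum(dim=1) < 0
        normals[d[flip]] *= -1
        visited[d] = True
        frontier = torch.zeros_like(visited)
        frontier[d] = True
    return normals


# Tiny example: a 5-node tree whose edges are stored in mixed directions.
pos = torch.tensor([[0., 0., 5.], [1., 0., 4.], [2., 0., 3.],
                    [0., 1., 2.], [0., 2., 1.]])
normals = torch.tensor([[0., 0., -1.], [0., 0., 1.], [0., 0., -1.],
                        [0., 0., -1.], [0., 0., 1.]])
mst_edge_index = torch.tensor([[0, 2, 3, 4],
                               [1, 1, 0, 3]])
oriented = flip_normals_bfs(normals, pos, mst_edge_index)
```

With this approach the number of Python-level iterations equals the depth of the tree rather than the number of nodes, and it also avoids hitting Python's recursion limit on large point clouds. The gain is limited when the MST is close to a single long chain, since each level then contains only a few nodes.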