Hi all, I've run into a simple problem in PyTorch.
I have a vector num_nodes that holds the number of nodes in each graph, and I want to build a vector batch that records which graph each node belongs to. This is easy to write with a for-loop, but the loop has high latency. Is there an operation I can use to vectorize it?
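import torch

# input
num_nodes = torch.tensor([2, 3])
batch = torch.empty(int(num_nodes.sum()), dtype=torch.long)
cum = 0  # start offset of the current graph's nodes
for i, num in enumerate(num_nodes.tolist()):
    batch[cum : cum + num] = i  # every node of graph i gets label i
    cum += num
# output
# batch: tensor([0, 0, 1, 1, 1])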
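A possible loop-free version, sketched under the assumption that num_nodes is a 1-D integer tensor, uses torch.repeat_interleave: repeat each graph index i exactly num_nodes[i] times. A second, cumsum-based variant is included for comparison; it marks where each graph starts and takes a running sum. Both helper names (batch_from_counts, batch_from_counts_cumsum) are hypothetical, not part of PyTorch.

```python
import torch


def batch_from_counts(num_nodes: torch.Tensor) -> torch.Tensor:
    """Graph index of every node: graph i's index repeated num_nodes[i] times."""
    graph_ids = torch.arange(num_nodes.numel(), device=num_nodes.device)
    return torch.repeat_interleave(graph_ids, num_nodes)


def batch_from_counts_cumsum(num_nodes: torch.Tensor) -> torch.Tensor:
    """Same result: put a +1 where each later graph starts, then running-sum it."""
    total = int(num_nodes.sum())
    marks = torch.zeros(total, dtype=torch.long, device=num_nodes.device)
    # Start offsets of graphs 1..n-1; drop offsets == total (trailing empty graphs).
    starts = torch.cumsum(num_nodes, 0)[:-1]
    starts = starts[starts < total]
    # index_add_ accumulates, so empty graphs (repeated offsets) still count.
    marks.index_add_(0, starts, torch.ones_like(starts))
    return torch.cumsum(marks, 0)


num_nodes = torch.tensor([2, 3])
print(batch_from_counts(num_nodes))         # tensor([0, 0, 1, 1, 1])
print(batch_from_counts_cumsum(num_nodes))  # tensor([0, 0, 1, 1, 1])
```

repeat_interleave is the shorter of the two and handles graphs with zero nodes without special cases. The cumsum variant only uses elementwise ops and a prefix sum, which can be handy if you already have the cumulative offsets around.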