```python
import torch

# placeholder values: e_min and epsilon are not given in the question
e_min = 0.1
epsilon = 0.9

# input
x = torch.rand((32, 512))
# 10 nodes in a graph
nodes = torch.rand(10, 512)
# edges between the nodes
edges = torch.zeros([10, 10])
# each input has an assigned node
correct_nodes = torch.tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9,
                              0, 1, 2, 3, 4, 5, 6, 7, 8, 9,
                              0, 1, 2, 3, 4, 5, 6, 7, 8, 9,
                              0, 1])

for i, xi in enumerate(x):
    # get the correct node for the input
    correct_node = correct_nodes[i]
    # for each node in the graph (nodes.shape is a tuple, so take its first entry)
    for node in range(nodes.shape[0]):
        # edge operations between the correct node and other nodes in the graph
        if edges[correct_node, node] < e_min:
            edges[correct_node, node] = 0
        if edges[correct_node, node] > 0:
            edges[correct_node, node] = edges[correct_node, node] * epsilon
```
I might have an idea, but I’m making a few assumptions…
This might also only be useful if, in your scenario, the edges matrix is larger than in the given example.
If it is significantly larger, I would suggest looking at sparse representations, and maybe computing this outside of Python itself.
First observation: it seems you aren’t using xi when looping over x, and you only use i to index into correct_nodes, so the loop can be written like this:
```python
for i in correct_nodes:
    for j in range(nodes.shape[0]):
        if edges[i, j] < e_min:
            edges[i, j] = 0
        if edges[i, j] > 0:
            edges[i, j] = edges[i, j] * epsilon
```
Second observation: the elements of correct_nodes index into the 0th dimension (i) of edges, i.e. each one selects a row.
Given the two conditions in the inner loop, we are simply passing over every element of a row once for each time that row’s node appears in correct_nodes. In other words, if correct_nodes contains duplicates, we go over the same (i, j) multiple times.
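A small sketch of that observation, assuming correct_nodes is a 1-D tensor of non-negative node indices as in the question: `torch.bincount` counts how many times each row of edges is swept by the nested loop.

```python
import torch

# same assignment as in the question: 32 inputs cycling over 10 nodes
correct_nodes = torch.tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9] * 3 + [0, 1])

# visits[i] = number of times row i of `edges` is visited by the nested loop
visits = torch.bincount(correct_nodes, minlength=10)
print(visits.tolist())  # [4, 4, 3, 3, 3, 3, 3, 3, 3, 3]
```

Nodes 0 and 1 appear four times and every other node three times, so rows 0 and 1 get one extra pass.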
Lastly, I’m going out on a limb here and assuming epsilon is used to reduce the edge value, i.e. 0 < epsilon < 1.
- cond1: if e < e_min, then e → 0. On a revisit: once e = 0 it is never changed again (0 is not > 0, so cond2 no longer applies).
- cond2: if e > 0, then e → e * epsilon. On a revisit: the smaller value can now trigger cond1.

Therefore: first compute e := e * epsilon^(encounters), and then check condition 1.
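To sanity-check this reasoning on a single edge value, the sketch below replays k visits of the two conditions and compares them with a closed form. The function names and the values of e_min and epsilon are hypothetical, and it assumes 0 < epsilon < 1 and e_min > 0. Because cond1 is checked before cond2 on every visit, the exact closed form scales by epsilon^(k - 1), checks condition 1, and only then applies the last factor of epsilon.

```python
def visit_repeatedly(e, k, e_min, epsilon):
    """Apply cond1 then cond2 to one edge value, k times (what the loop does)."""
    for _ in range(k):
        if e < e_min:
            e = 0.0
        if e > 0:
            e = e * epsilon
    return e


def closed_form(e, k, e_min, epsilon):
    """Same result without the loop, assuming 0 < epsilon < 1 and k >= 1."""
    scaled = e * epsilon ** (k - 1)  # value cond1 sees on the last visit
    if scaled < e_min:               # cond1 fired on some visit, so the edge stays 0
        return 0.0
    return scaled * epsilon          # cond2 on the last visit


# placeholder values, not taken from the question
e_min, epsilon = 0.1, 0.5
for e in [0.05, 0.3, 0.8, 1.0]:
    for k in range(1, 6):
        assert visit_repeatedly(e, k, e_min, epsilon) == closed_form(e, k, e_min, epsilon)
```

With epsilon = 0.5 every multiplication is exact in floating point, so the two functions can be compared with `==`.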
So maybe something like this is what you are looking for if correct_nodes is a long list of nodes:
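```python
from collections import Counter

import torch

# count how many times each node (row of edges) appears in correct_nodes
visits_counter = Counter(correct_nodes.tolist())
visited_node_encounters = torch.zeros(10)
for i in range(10):
    if i in visits_counter:
        visited_node_encounters[i] = visits_counter[i]

# encounters[i, j] = number of visits of row i (broadcast the count along each row)
encounters = torch.ones(10, 10) * visited_node_encounters.unsqueeze(1)
# cond2 applied `encounters` times: e -> e * epsilon**encounters
cond2 = edges * epsilon ** encounters
# cond1: zero out edges that fell below e_min (rows never visited are left alone)
cond1 = ((cond2 >= e_min) | (encounters == 0)).float() * cond2
edges = cond1
```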
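*Note: here I’m ignoring the order of these conditions. If the order does matter, then for the visited rows apply edges = edges * epsilon^(encounters - 1), followed by edges = cond1 * epsilon (i.e. zero out the edges below e_min first, then apply the last factor of epsilon).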
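The order-respecting variant from the note can be sketched as a fully vectorized function and checked against the original nested loop. This is a minimal sketch assuming 0 < epsilon < 1 and e_min > 0; the helper names `loop_update` and `vectorized_update` and the numeric values are hypothetical, chosen for illustration.

```python
import torch


def loop_update(edges, correct_nodes, e_min, epsilon):
    """Reference: the nested loop from the question, applied to a copy."""
    edges = edges.clone()
    for i in correct_nodes.tolist():
        for j in range(edges.shape[1]):
            if edges[i, j] < e_min:
                edges[i, j] = 0
            if edges[i, j] > 0:
                edges[i, j] = edges[i, j] * epsilon
    return edges


def vectorized_update(edges, correct_nodes, e_min, epsilon):
    """Same result without Python loops, assuming 0 < epsilon < 1."""
    n = edges.shape[0]
    # encounters[i] = visits of row i, shaped (n, 1) so it broadcasts along each row
    encounters = torch.bincount(correct_nodes, minlength=n).to(edges.dtype).unsqueeze(1)
    visited = encounters > 0
    # epsilon^(encounters - 1): the value cond1 sees on the last visit of each row
    scaled = edges * epsilon ** (encounters - 1).clamp(min=0)
    # cond1 zeroes anything below e_min; cond2 applies one last factor of epsilon
    updated = torch.where(scaled < e_min, torch.zeros_like(scaled), scaled * epsilon)
    # rows that are never visited keep their original values
    return torch.where(visited, updated, edges)


torch.manual_seed(0)
edges = torch.rand(10, 10)
correct_nodes = torch.tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9] * 3 + [0, 1])
e_min, epsilon = 0.1, 0.5  # placeholder values, not from the question

expected = loop_update(edges, correct_nodes, e_min, epsilon)
result = vectorized_update(edges, correct_nodes, e_min, epsilon)
print(torch.allclose(expected, result))  # True
```

The loop version is only there as a reference for small matrices; the vectorized one does a single pass over edges regardless of how long correct_nodes is.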