```
import torch

# e_min and epsilon are not defined in the question; placeholder values
e_min = 0.1
epsilon = 0.5

# input
x = torch.rand((32, 512))
# 10 nodes in a graph
nodes = torch.rand(10, 512)
# edges between the nodes
edges = torch.zeros([10, 10])
# each input has an assigned node
correct_nodes = torch.tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 1])
for i, xi in enumerate(x):
    # get the correct node for the input
    correct_node = correct_nodes[i]
    # for each node in the graph
    for node in range(nodes.shape[0]):
        # edge operations between the correct node and other nodes in the graph
        if edges[correct_node, node] < e_min:
            edges[correct_node, node] = 0
        if edges[correct_node, node] > 0:
            edges[correct_node, node] = edges[correct_node, node] * epsilon
```

I might have an idea, but I’m making a few assumptions…

This is probably only worth it if your edges matrix is larger than in the given example.

If it is significantly larger, I would suggest looking at sparse representations, and maybe computing this outside of Python itself.

First observation: you aren’t using xi when looping over x, and i is only used to index correct_nodes, so the loop can be written like this:

```
for i in correct_nodes:
    for j in range(nodes.shape[0]):
        if edges[i, j] < e_min:
            edges[i, j] = 0
        if edges[i, j] > 0:
            edges[i, j] = edges[i, j] * epsilon
```

Second observation: the elements of correct_nodes index the 0th dimension of edges (the rows, i).

If all of that holds, then every occurrence of a node in “correct_nodes” is one more pass over the same row of edges. In other words, if “correct_nodes” contains duplicates, we go over the same (i, j) entries multiple times.
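To see how often each row is revisited, the duplicates can be counted directly. This is a minimal sketch using the correct_nodes tensor from the question and assuming 10 nodes; `torch.bincount` is just one convenient way to count and is not part of the original code.

```python
import torch

correct_nodes = torch.tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9,
                              0, 1, 2, 3, 4, 5, 6, 7, 8, 9,
                              0, 1, 2, 3, 4, 5, 6, 7, 8, 9,
                              0, 1])

# counts[i] = number of times row i of edges is visited by the loop
counts = torch.bincount(correct_nodes, minlength=10)
print(counts)  # tensor([4, 4, 3, 3, 3, 3, 3, 3, 3, 3])
```

Rows 0 and 1 are visited four times and every other row three times, so the inner loop runs 32 times over only 10 distinct rows.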

Lastly, I’m going out on a limb here and assuming epsilon is used to reduce the edge value, i.e. 0 < epsilon < 1.

```
cond1: if e < e_min: e -> 0
  (on revisit: once e == 0 it is never reset)
cond2: if e > 0: e -> e * epsilon
  (on revisit: the smaller e can trigger cond1)
Therefore run:
  e := e * epsilon^(encounters)
  and then check cond1
```
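The collapse can be checked on a single edge value. This is a small sketch in plain Python; the function names and the numbers are hypothetical, chosen so the arithmetic is exact in floating point.

```python
def loop_update(e, e_min, epsilon, encounters):
    """Apply cond1 then cond2 once per encounter, as the nested loop does."""
    for _ in range(encounters):
        if e < e_min:
            e = 0.0
        if e > 0:
            e = e * epsilon
    return e


def collapsed_update(e, e_min, epsilon, encounters):
    """Scale once by epsilon**encounters, then check cond1 a single time."""
    e = e * epsilon ** encounters
    return 0.0 if e < e_min else e


# hypothetical values: e = 1.0, e_min = 0.125, epsilon = 0.5
print(loop_update(1.0, 0.125, 0.5, 2), collapsed_update(1.0, 0.125, 0.5, 2))  # 0.25 0.25
print(loop_update(1.0, 0.125, 0.5, 4), collapsed_update(1.0, 0.125, 0.5, 4))  # 0.0625 0.0
```

The two agree unless the final multiplication is the one that pushes e below e_min: the loop checks cond1 before each multiply, so its last multiply is never checked, while the collapsed form checks after all of them.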

**So maybe something like this is what you are looking for, if correct_nodes is a long list of nodes?**

```
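from collections import Counter

# number of times each row index appears in correct_nodes (0 if never visited)
visits_counter = Counter(correct_nodes.tolist())
visited_node_encounters = torch.zeros(10)
for i in range(10):
    if i in visits_counter:
        visited_node_encounters[i] = visits_counter[i]
# one exponent per row of edges, broadcast across the columns
encounters = visited_node_encounters.unsqueeze(1) * torch.ones(10, 10)
# cond2 applied `encounters` times in one step
cond2 = edges * epsilon ** encounters
# cond1: zero every entry that ended up below e_min
# (also thresholds unvisited rows; in this example every row is visited)
edges = cond1 = (1 - ((cond2 < e_min) * 1)) * cond2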
```

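*Note: here I’m ignoring the order of the two conditions. In the loop, cond1 is checked before each multiplication, so the last multiplication is never checked. If the order matters, apply edges = edges * epsilon^(encounters-1), then cond1, and finally multiply the result by epsilon once more.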
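The order-preserving variant from the note can be sketched as below and compared against the original nested loop. This is a sketch, not a definitive implementation: the function names are hypothetical, `torch.bincount` replaces the Counter, and it assumes 0 < epsilon <= 1 and e_min > 0. The values of e_min and epsilon are placeholders, since the question does not give them.

```python
import torch


def loop_reference(edges, correct_nodes, e_min, epsilon):
    """The nested loop from the question, kept as ground truth."""
    edges = edges.clone()
    for i in correct_nodes.tolist():
        for j in range(edges.shape[1]):
            if edges[i, j] < e_min:
                edges[i, j] = 0
            if edges[i, j] > 0:
                edges[i, j] = edges[i, j] * epsilon
    return edges


def vectorized_update(edges, correct_nodes, e_min, epsilon):
    """Order-preserving closed form; assumes 0 < epsilon <= 1 and e_min > 0."""
    counts = torch.bincount(correct_nodes, minlength=edges.shape[0])
    k = counts.to(edges.dtype).unsqueeze(1)  # visits per row, shape (n, 1)
    # value seen by the last cond1 check: scaled by epsilon^(k-1)
    before_last = edges * epsilon ** (k - 1).clamp(min=0)
    survives = before_last >= e_min
    # survivors get the final, unchecked multiplication; the rest become 0
    updated = torch.where(survives, before_last * epsilon, torch.zeros_like(edges))
    # rows that are never visited stay untouched
    return torch.where(k > 0, updated, edges)


torch.manual_seed(0)
edges = torch.rand(10, 10)
correct_nodes = torch.tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9,
                              0, 1, 2, 3, 4, 5, 6, 7, 8, 9,
                              0, 1, 2, 3, 4, 5, 6, 7, 8, 9,
                              0, 1])
e_min, epsilon = 0.1, 0.5  # placeholder values

out = vectorized_update(edges, correct_nodes, e_min, epsilon)
print(torch.allclose(out, loop_reference(edges, correct_nodes, e_min, epsilon)))
```

With a general epsilon, repeated multiplication and a single power can differ by rounding, so entries sitting exactly at the e_min boundary may not match bit for bit.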