Is there a way to make the following code run faster or to avoid the loop?

x = torch.rand((32, 512))
#10 nodes in a graph
nodes = torch.rand(10, 512)
#edges between the nodes
edges = torch.zeros([10,10])
#each input has an assigned node
correct_nodes = torch.tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 1])

for i, xi in enumerate(x):
    #get the correct node for the input
    correct_node = correct_nodes[i]
    #for each node in the graph
    for node in range(nodes.shape[0]):
        #edge operations between the correct node and other nodes in the graph
        if edges[correct_node, node] < e_min:
            edges[correct_node, node] = 0
        if edges[correct_node, node] > 0:
            edges[correct_node, node] = edges[correct_node, node] * epsilon

I might have an idea, but I’m making a few assumptions…
This is probably only worthwhile if your edges matrix is larger than in the given example.
If it is significantly larger, I would suggest looking at sparse representations, and maybe doing the computation outside of Python itself.

First observation: you aren’t using xi when looping over x, and i is only used to index correct_nodes, so the loop can be written like this:

for i in correct_nodes:
    for j in range(nodes.shape[0]):
        if edges[i, j] < e_min:
            edges[i, j] = 0
        if edges[i, j] > 0:
            edges[i, j] = edges[i, j] * epsilon

Second observation: the elements of correct_nodes index the 0th dimension (the rows, i) of edges.
If the two conditions above are the only ones, each visit simply sweeps over every element of row i. In other words, if “correct_nodes” contains duplicates, we go over the same (i, j) several times.
Lastly, I’m going out on a limb here and assuming epsilon is used to reduce the edge value, i.e. it lies between 0 and 1.
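The number of times each row gets swept can be counted up front. A minimal sketch, assuming torch.bincount is acceptable here (it is not used in the original post); the name counts is introduced only for illustration:

```python
import torch

# same assignment as in the question: nodes 0..9 three times, then 0 and 1 again
correct_nodes = torch.tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9] * 3 + [0, 1])

# counts[i] = number of times row i of edges would be swept by the loop;
# minlength keeps one slot per node even if a node is never visited
counts = torch.bincount(correct_nodes, minlength=10)
print(counts)  # tensor([4, 4, 3, 3, 3, 3, 3, 3, 3, 3])
```

So with the example data, rows 0 and 1 are processed four times and every other row three times.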

cond1> if e < e_min: e -> 0
(on revisiting e): once e = 0 it stays 0, neither condition changes it again.
cond2> if e > 0: e -> e * epsilon
(on revisiting e): the reduced value can trigger cond1 on a later visit.
Therefore run:
> e := e * epsilon^(encounters)
> and then check cond1
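To make the reasoning concrete, here is a small scalar sketch comparing the repeated visits with the collapsed form. The values e_min = 0.1 and epsilon = 0.5 are hypothetical (the question does not give them), and the function names are made up for this example:

```python
def visit(e, e_min, epsilon):
    """One pass of the two conditions from the loop over a single edge value."""
    if e < e_min:   # cond1
        e = 0.0
    if e > 0:       # cond2
        e = e * epsilon
    return e

def repeated(e, e_min, epsilon, encounters):
    """Apply the loop body once per encounter, as the original code does."""
    for _ in range(encounters):
        e = visit(e, e_min, epsilon)
    return e

def collapsed(e, e_min, epsilon, encounters):
    """Decay once per encounter in one step, then check cond1 a single time."""
    e = e * epsilon ** encounters
    return 0.0 if e < e_min else e

# hypothetical values: e_min = 0.1, epsilon = 0.5
print(repeated(0.8, 0.1, 0.5, 3), collapsed(0.8, 0.1, 0.5, 3))  # 0.1 0.1
print(repeated(0.3, 0.1, 0.5, 2), collapsed(0.3, 0.1, 0.5, 2))  # 0.075 0.0
```

The first case agrees. In the second, the loop checks cond1 before the last multiplication, so it keeps 0.075, while the collapsed form checks afterwards and zeroes it. That difference only matters if the order of the conditions matters to you.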

So maybe something like this is what you are looking for, if correct_nodes is a long list of nodes?

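from collections import Counter

# count how often each node appears in correct_nodes
visits_counter = Counter(correct_nodes.tolist())
# zeros, so rows that are never visited keep their values (epsilon**0 == 1)
visited_node_encounters = torch.zeros(10)
for i in range(10):
    if i in visits_counter:
        visited_node_encounters[i] = visits_counter[i]
# one count per row i, broadcast across the columns j
encounters = torch.ones(10, 10) * visited_node_encounters.unsqueeze(1)

# cond2 applied once per encounter (epsilon is assumed to be a plain float)
cond2 = edges * epsilon**encounters

# cond1, only in rows that were actually visited
below = (encounters > 0) & (cond2 < e_min)
edges = cond1 = (1 - below * 1) * cond2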

*Note: here I’m ignoring the order of these conditions. If the order does matter, use edges = edges * epsilon^(encounters-1), then apply cond1 to that, and finally multiply the result by epsilon once more.
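A sketch of that order-preserving variant, checked against the simplified double loop. It assumes 0 < epsilon <= 1 and e_min > 0, so that the check on the last visit is the binding one. The values of e_min and epsilon are hypothetical, and torch.bincount and torch.where are my choices here rather than something from the question:

```python
import torch

def loop_reference(edges, correct_nodes, e_min, epsilon):
    """The simplified double loop from earlier in the post, run on a copy."""
    edges = edges.clone()
    for i in correct_nodes.tolist():
        for j in range(edges.shape[1]):
            if edges[i, j] < e_min:
                edges[i, j] = 0
            if edges[i, j] > 0:
                edges[i, j] = edges[i, j] * epsilon
    return edges

def vectorized(edges, correct_nodes, e_min, epsilon):
    """Order-preserving form; assumes 0 < epsilon <= 1 and e_min > 0."""
    # visits per row, shaped (n, 1) so it broadcasts across the columns
    encounters = torch.bincount(correct_nodes, minlength=edges.shape[0]).unsqueeze(1)
    visited = encounters > 0
    # value that cond1 sees on the last visit: decayed (encounters - 1) times
    before_last = edges * epsilon ** (encounters - 1).clamp(min=0)
    survives = before_last >= e_min
    # survivors get the final multiplication, everything else becomes 0
    updated = torch.where(survives, before_last * epsilon, torch.zeros_like(edges))
    # rows that never appear in correct_nodes are left untouched
    return torch.where(visited, updated, edges)

torch.manual_seed(0)
edges = torch.rand(10, 10)
correct_nodes = torch.tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9] * 3 + [0, 1])
e_min, epsilon = 0.1, 0.5  # hypothetical values, not given in the post

expected = loop_reference(edges, correct_nodes, e_min, epsilon)
result = vectorized(edges, correct_nodes, e_min, epsilon)
print(torch.allclose(result, expected))
```

Floating point rounding can differ slightly between repeated multiplication and a single power, so values sitting right at e_min may land on different sides of the threshold for other choices of epsilon.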