# Improving Function Performance

Hi folks, I’m seeking some help with my function below. I’d like to get some opinions on whether there’s a way to improve its inference performance.

Here’s some context on what I want to do. `beliefs` is a `batch x n x m` tensor, where each `batch` element is a factor graph with `n` variable nodes (rows of the matrix) and `m` factors (columns of the matrix). The graph is bipartite: `n` variable nodes on one side, `m` factors on the other, with edges only between the two sides. From this I calculate two sets of messages, `fac2var` (factor-to-variable) and `var2fac` (variable-to-factor). Each factor aggregates by taking the 2nd-lowest incoming message from the variable nodes connected to it. As you can see, this is done via the `for`-loop in the message propagation, and that section is extremely slow at the moment because of the iterative looping.
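For concreteness, here is a minimal sketch of that layout (tensor names and the `BIG` constant are my own, not from the actual code), including the trick of inflating disconnected entries so a column-wise min only ever sees real edges:

```python
import torch

batch, n, m = 2, 4, 3
beliefs = torch.rand(batch, n, m)                 # per-(node, factor) beliefs
adj = torch.randint(0, 2, (batch, n, m)).float()  # bipartite edges: var <-> factor

var2fac = beliefs * adj                           # messages only along real edges

# Inflate disconnected entries so a column-wise min only considers real edges
BIG = 99999.0
masked = var2fac + (1.0 - adj) * BIG
col_min = masked.min(dim=1).values                # per-factor minimum over variables
```

Each column of `col_min` is then the cheapest incoming message for one factor, computed for the whole batch at once with no Python loop.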

Basically, I loop through all `n` variable nodes and, for each one, mask out the message it would be sending to each factor. Then, for every factor, I find the 2nd-lowest message it receives. It has to work this way because the message a factor `m` sends to node `n` is the 2nd-lowest message `m` receives, NOT including the message it received from node `n` itself. Hence, I need to mask out the messages accordingly.

I was wondering whether this can be done in a `torch.scatter` or `torch.gather` way, which could potentially speed it up a fair bit? I’m open to other suggestions too!
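One way to remove the per-node loop entirely is a leave-one-out second minimum built on `torch.topk` rather than scatter/gather. The idea: take the three smallest incoming messages per factor column; if row `i` holds the 1st or 2nd smallest, then excluding `i` the 2nd-lowest becomes the 3rd value, otherwise it stays the 2nd. This is a sketch (the function name is mine, and it assumes every factor has at least 3 connected variables):

```python
import torch

def fac2var_second_min(var2fac, adj, big=99999.0):
    """For each factor column j and each variable row i, return the 2nd-lowest
    var2fac message into j among connected variables, EXCLUDING row i's own
    message. Assumes every factor column has >= 3 connected variables."""
    # Inflate disconnected entries so they never rank among the smallest
    masked = torch.where(adj.bool(), var2fac, torch.full_like(var2fac, big))
    # Three smallest incoming messages per factor column, with their row indices
    vals, idx = torch.topk(masked, k=3, dim=1, largest=False)  # (batch, 3, m)
    v2, v3 = vals[:, 1:2, :], vals[:, 2:3, :]
    rows = torch.arange(var2fac.shape[1], device=var2fac.device).view(1, -1, 1)
    # Row i in the top-2 -> removing it promotes v3; otherwise the answer is v2
    in_top2 = (rows == idx[:, 0:1, :]) | (rows == idx[:, 1:2, :])
    return torch.where(in_top2, v3, v2) * adj  # zero out non-edges
```

This computes all `n x m` leave-one-out messages in one batched `topk` plus a couple of elementwise ops, so the cost no longer scales with a Python-level loop over `n`.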

Thanks a lot for the help.

```python
def constraint_bp(beliefs, adj, bp_iter):
    # For the batch-level implementation, beliefs is arranged as [batch_size, max_n, max_m]

    # We can construct a few matrices to handle the problem: the belief matrix,
    # the adjacency matrix (b x n x m) and the factor->var matrix (b x n x m),
    # where n is the number of variable nodes and m the number of factors
    batch_size, n, m = adj.shape
    fac2var = torch.zeros((batch_size, n, m), dtype=torch.float, device=beliefs.device)
    # Mask which inflates values by 99999 so that disconnected nodes are not
    # considered in the message aggregation
    mask = (1.0 - adj) * 99999.0

    lamb = 0.2  # damping factor
    # Find all variable-to-factor messages
    # BP algorithm

    # Note that for an updated var2fac message, we can simply take the beliefs and subtract
    # the previous fac2var message for that variable. The local beliefs were
    # initialized with d(e), which is the local factor for the variable. Hence, the total belief
    # of a node is d(e) + all incoming messages (which are essentially damped updates), so we can
    # always just take the message on the edge out. Also note that if we were to convert this
    # problem from MIN to MAX, we need to be careful about how we treat d(e). Right now, there's
    # a general assumption that all updates are POSITIVE, because of the nature of d(e).

    # var2fac must exist before the first aggregation below; with fac2var
    # initialized to zero, this matches the round-0 update
    var2fac = torch.multiply(beliefs, adj)
    for rounds in range(bp_iter):
        if rounds != 0:
            var2fac = torch.multiply(beliefs, adj) - fac2var

        # Compile the message for every factor; the adjacency matrix masks away non-edges.
        # To calculate the message of column_j (fac_j) to row_i (var_i), we take the
        # var2fac messages, mask away row_i, and take the min/max of column_j
        fac2var_tilde = torch.zeros((batch_size, n, m), dtype=torch.float, device=beliefs.device)