Hi folks, I’m seeking some help with my function below. I’d like some opinions on whether there’s a way to improve its inference performance.

Here’s some context for what I want to do. Each element of the batch is a factor graph with `n` variable nodes and `m` factors; the graph is bipartite (`n` nodes on one side, `m` factors on the other, with edges only between the two sides), and it is represented by a `batch x n x m` adjacency matrix whose rows are the nodes and whose columns are the factors. The per-node beliefs live in a `batch x n x 1` tensor. From this, I calculate two kinds of messages, `fac2var` and `var2fac`, which are factor-to-variable and variable-to-factor messages respectively. The message a factor sends out is the 2nd-lowest incoming message from the variable nodes. As you can see, this is done via the for-loop in the message propagation, and that section is extremely slow at the moment because of the iterative Python loop.

Basically, I loop through all `n` nodes and mask out the message each node would be sending to each factor. Then, for every factor, I find the 2nd-lowest message it receives among the remaining nodes. It’s done this way because, for a factor `m`, the message it sends to a node `n` must be the 2nd-lowest message it receives NOT including the message it received from node `n` itself. Hence, I need to mask out the messages accordingly.
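To make the masking concrete, here’s a toy case with one factor receiving messages from four variable nodes (the numbers are made up):

```python
import torch

# Incoming messages to one factor from nodes 0..3 (hypothetical values).
msgs = torch.tensor([3.0, 1.0, 4.0, 2.0])

# Message the factor sends back to node 1: node 1's own message (1.0)
# is masked out, so the candidates are [3.0, 4.0, 2.0].
others = torch.cat((msgs[:1], msgs[2:]))
# 2nd-lowest of the remaining messages: 3.0
out_to_node1 = torch.topk(others, k=2, largest=False)[0][1]
```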

I was wondering if this could be done in a `torch.scatter` or `torch.gather` way, which could potentially speed it up a fair bit? I’m open to other suggestions too!

Thanks a lot for the help.

```
import torch

def constraint_bp(beliefs, adj, bp_iter):
    # For the batch-level implementation, beliefs is arranged as [batch_size, max_n, 1].
    # We work with three matrices: the belief matrix (b x n x 1),
    # the adjacency matrix (b x n x m) and the factor->var message matrix (b x n x m),
    # where n is the number of variable nodes and m the number of factors.
    batch_size, n, m = adj.size()
    fac2var = torch.zeros((batch_size, n, m), dtype=torch.float, device=beliefs.device)
    # Build a mask that inflates non-edges by 99999 so that disconnected nodes
    # are not considered in the message aggregation.
    mask = adj.clone().detach()
    mask[mask == 0] = 99999
    mask[mask == 1] = 0
    lamb = 0.2  # damping factor
    # Find all variable-to-factor messages
    var2fac = torch.multiply(beliefs, adj)
    # BP algorithm
    # Note that for an updated var2fac message, we can simply take the beliefs and subtract the
    # previous fac2var message for that variable. The local beliefs were
    # initialized with d(e), the local factor for the variable, so the total belief of a node
    # is d(e) + all incoming messages (which are essentially damped updates). Thus, we can always
    # just take the message on the edge out. Also note that if we were to convert this problem
    # from MIN to MAX, we would need to be careful about how we treat d(e). Right now there's a
    # general assumption that all updates are POSITIVE, because of the nature of d(e).
    for rounds in range(bp_iter):
        if rounds != 0:
            var2fac = torch.multiply(beliefs, adj) - fac2var
        # Compile the message for every factor; the adjacency mask hides non-edges.
        # To calculate the message of column_j (fac_j) to row_i (var_i), we take the var2fac
        # messages, mask away row_i, and take the 2nd-lowest entry of column_j.
        fac2var_tilde = torch.zeros((batch_size, n, m), dtype=torch.float, device=beliefs.device)
        masked_var2fac = var2fac + mask
        for i in range(n):
            # Drop row i, then take the 2nd-smallest of the remaining rows per column.
            neighbors = torch.cat((masked_var2fac[:, :i, :], masked_var2fac[:, (i + 1):, :]), dim=1)
            fac2var_tilde[:, i, :] = torch.topk(neighbors, k=2, dim=1, largest=False)[0][:, 1, :]
        fac2var_tilde = torch.multiply(fac2var_tilde, adj) * -1.0
        # Damped update of fac2var and the beliefs.
        eps = fac2var_tilde - fac2var
        fac2var = fac2var + lamb * eps
        beliefs = beliefs + torch.sum(eps, dim=2).unsqueeze(2)
    return beliefs
```
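For reference, here’s the kind of vectorization I had in mind for the inner loop: a single `torch.topk` with `k=3` per column gives everything needed to read off the 2nd-lowest value excluding any given row. The helper name is mine, and it assumes `n >= 3` (the loop version needs that too):

```python
import torch

def second_lowest_excluding_self(masked_var2fac):
    # masked_var2fac: (batch, n, m), with non-edges inflated by the mask.
    # Returns out where out[b, i, j] is the 2nd-lowest value in column j
    # of batch b among all rows EXCEPT row i, with no Python loop over n.
    vals, idx = torch.topk(masked_var2fac, k=3, dim=1, largest=False)
    v2, v3 = vals[:, 1, :].unsqueeze(1), vals[:, 2, :].unsqueeze(1)  # (batch, 1, m)
    i1, i2 = idx[:, 0, :].unsqueeze(1), idx[:, 1, :].unsqueeze(1)    # (batch, 1, m)
    n = masked_var2fac.size(1)
    rows = torch.arange(n, device=masked_var2fac.device).view(1, n, 1)
    # If row i is the column's lowest or 2nd-lowest entry, excluding it
    # promotes the 3rd-lowest (v3) into 2nd place; otherwise the answer is v2.
    is_top2 = (rows == i1) | (rows == i2)
    return torch.where(is_top2, v3, v2)
```

The inner loop would then collapse to `fac2var_tilde = second_lowest_excluding_self(masked_var2fac)`. I haven’t benchmarked it against the loop yet, but it does one batched topk instead of `n` of them.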