Improving Function Performance

Hi folks, I’m looking for some help with the function below. I’d like to get some opinions on whether there’s a way to improve its inference performance.

Here’s some context on what I want to do. beliefs is a batch x n x 1 tensor and adj is a batch x n x m adjacency matrix: each batch element is a factor graph with n variable nodes (rows of the matrix) and m factors (columns of the matrix). The graph is bipartite (n nodes on one side, m factors on the other, with edges only between a node and a factor). From this, I calculate the fac2var and var2fac messages, i.e. the factor-to-variable and variable-to-factor messages respectively. The message a factor sends is the 2nd-lowest incoming message from the variable nodes (n nodes). As you can see, this is done in the for-loop in the message-propagation step. That section is extremely slow at the moment, because it iterates over the n nodes one at a time.
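
To make the layout concrete, here is a tiny hand-made graph. The numbers are made up for illustration, and it assumes (as the comments in the code state) that beliefs is b x n x 1 and adj is a 0/1 tensor of shape b x n x m. It only shows how beliefs * adj broadcasts each node's belief onto its edges and how the 99999 mask keeps non-edges out of a column-wise minimum.

```python
import torch

# Hypothetical toy graph: 1 batch element, n = 3 variables, m = 2 factors.
# adj[b, i, j] = 1 means variable i is connected to factor j.
adj = torch.tensor([[[1.0, 0.0],
                     [1.0, 1.0],
                     [0.0, 1.0]]])                 # shape (1, 3, 2)
beliefs = torch.tensor([[[0.5], [1.0], [2.0]]])     # shape (1, 3, 1)

# (1, 3, 1) * (1, 3, 2) broadcasts each node's belief along its row;
# non-edges become 0.
var2fac = beliefs * adj
# -> [[0.5, 0.0], [1.0, 1.0], [0.0, 2.0]]

# +99999 on non-edges so a column-wise min ignores them
# (equivalent to the mask in constraint_bp when adj is 0/1).
mask = torch.where(adj == 1, torch.zeros_like(adj), torch.full_like(adj, 99999.0))
masked_var2fac = var2fac + mask

# Column for factor 0: variable 2 is not connected, so it gets 99999.
print(masked_var2fac[0, :, 0].tolist())  # [0.5, 1.0, 99999.0]
```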

Basically, I loop through all n nodes and, for each node, mask out the message it would be sending to each factor. Then, for every factor, I find the 2nd-lowest message it receives. I do it this way because the message that factor m sends to node n is the 2nd-lowest message it receives, NOT including the message it received from node n. Hence, I need to mask out that node’s message.
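
One way to drop the loop over i, sketched under the assumption that n >= 3 (the original loop has the same requirement): for each factor column, take the three smallest entries once with torch.topk. Removing row i only changes the answer when i is one of the two smallest rows of that column. If it is, the 2nd-lowest of the remaining entries is the 3rd-smallest overall; otherwise it is the 2nd-smallest overall. This uses torch.topk and torch.where rather than scatter/gather, and the helper name second_min_excluding_self is hypothetical.

```python
import torch


def second_min_excluding_self(x: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: for x of shape (batch, n, m), returns a tensor of
    the same shape where out[b, i, j] is the 2nd-smallest value of x[b, :, j]
    after removing row i. Requires n >= 3 (same as the original loop)."""
    # Three smallest values per column and their row indices, ascending order.
    vals, idx = torch.topk(x, k=3, dim=1, largest=False)  # each (batch, 3, m)
    n = x.size(1)
    rows = torch.arange(n, device=x.device).view(1, n, 1)  # (1, n, 1)
    # True where row i is one of the two smallest rows of column j.
    in_top2 = (rows == idx[:, 0:1, :]) | (rows == idx[:, 1:2, :])  # (batch, n, m)
    # Removing a top-2 row makes the 3rd-smallest the new 2nd-smallest;
    # removing any other row leaves the 2nd-smallest unchanged.
    return torch.where(in_top2, vals[:, 2:3, :], vals[:, 1:2, :])


# Worked example: a single column [1, 2, 3, 4].
x = torch.tensor([[[1.0], [2.0], [3.0], [4.0]]])  # shape (1, 4, 1)
print(second_min_excluding_self(x).flatten().tolist())  # [3.0, 3.0, 2.0, 2.0]
```

Inside constraint_bp, the whole `for i in range(n)` block would then reduce to `fac2var_tilde = second_min_excluding_self(masked_var2fac)`, replacing n `torch.cat` copies and n `topk` calls with a single `topk` per round. Ties are handled correctly because topk returns distinct row indices, so excluding row i always removes exactly one entry.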

I was wondering if this can be done with torch.scatter or torch.gather, which could potentially speed it up a fair bit. I’m open to other suggestions too!

Thanks a lot for the help.

    import torch

    def constraint_bp(beliefs, adj, bp_iter):
        # For the batched implementation, the graphs are padded and arranged as [batch_size, max_n, max_m].

        # We can construct a few matrices to handle the problem. We have 3 matrices: the belief matrix (b x n x 1),
        # the adjacency matrix (b x n x m) and the factor->var message matrix (b x n x m),
        # where n is the number of variable nodes and m the number of factors.
        batch_size, n, m = adj.size()
        fac2var = torch.zeros((batch_size, n, m), dtype=torch.float, device=beliefs.device)
        # Add a mask which inflates values by 99999 so that disconnected nodes are not considered in the message aggregation
        mask = adj.clone().detach()
        mask[mask == 0] = 99999
        mask[mask == 1] = 0

        lamb = 0.2  # damping factor for the fac2var updates
        # Find all variable to factor messages
        var2fac = torch.multiply(beliefs, adj)
        # BP algorithm

        # Note that for an updated var2fac message, we can simply take the beliefs and subtract the
        # previous fac2var message on that edge. The local beliefs were
        # initialized with d(e), the local factor for the variable, so the total belief of a node
        # is d(e) plus all incoming messages (which are damped updates). Thus, we can always recover
        # the message on an edge by subtracting it out. Also note that if we were to convert this problem from MIN to MAX,
        # we would need to be careful about how we treat d(e). Right now, there's a general assumption
        # that all updates are POSITIVE, because of the nature of d(e).

        for rounds in range(bp_iter):
            if rounds != 0:
                var2fac = torch.multiply(beliefs, adj) - fac2var

            # Compile the message for every factor; the mask (built from the adjacency matrix) hides non-edges.
            # To calculate the message from column_j (fac_j) to row_i (var_i), we take the var2fac messages,
            # mask away row_i, and take the 2nd-smallest value in column_j.
            fac2var_tilde = torch.zeros((batch_size, n, m), dtype=torch.float, device=beliefs.device)
            masked_var2fac = var2fac + mask
            for i in range(n):
                neighbors = torch.cat((masked_var2fac[:, :i, :], masked_var2fac[:, (i + 1):, :]), dim=1)
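                # topk(k=2, largest=False) returns the two smallest values in ascending order; index 1 is the 2nd-smallest
                fac2var_tilde[:, i, :] = torch.topk(neighbors, k=2, dim=1, largest=False)[0][:, 1, :]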
            fac2var_tilde = torch.multiply(fac2var_tilde, adj) * -1.0

            eps = fac2var_tilde - fac2var
            fac2var = fac2var + lamb * eps
            beliefs = beliefs + torch.sum(eps, dim=2).unsqueeze(2)

        return beliefs
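
A loop-free version of the whole function could look like the sketch below. It assumes adj is a 0/1 tensor, beliefs is (batch, n, 1) and n >= 3; the name constraint_bp_vectorized and the lamb keyword argument are hypothetical. It keeps the original update rules and swaps only the per-node loop for a top-3 trick: the 3 smallest entries of each factor column via torch.topk, then torch.where picks the 3rd-smallest if row i is one of the two smallest, else the 2nd-smallest.

```python
import torch


def constraint_bp_vectorized(beliefs, adj, bp_iter, lamb=0.2):
    """Hypothetical loop-free variant of constraint_bp (same update rules).
    beliefs: (batch, n, 1); adj: 0/1 tensor (batch, n, m); requires n >= 3."""
    n = adj.size(1)
    adj = adj.to(beliefs.dtype)
    fac2var = torch.zeros_like(adj)
    # Non-edges get +99999 so they never win the column-wise minimum.
    mask = (1.0 - adj) * 99999.0
    rows = torch.arange(n, device=adj.device).view(1, n, 1)

    for _ in range(bp_iter):
        # Variable-to-factor messages: belief minus the factor's own last message
        # (fac2var is all zeros in the first round, matching the original).
        var2fac = beliefs * adj - fac2var
        masked_var2fac = var2fac + mask

        # Top-3 trick: 2nd-smallest of each column with row i excluded.
        vals, idx = torch.topk(masked_var2fac, k=3, dim=1, largest=False)
        in_top2 = (rows == idx[:, 0:1, :]) | (rows == idx[:, 1:2, :])
        fac2var_tilde = torch.where(in_top2, vals[:, 2:3, :], vals[:, 1:2, :])
        fac2var_tilde = fac2var_tilde * adj * -1.0

        # Damped update, as in the original.
        eps = fac2var_tilde - fac2var
        fac2var = fac2var + lamb * eps
        beliefs = beliefs + eps.sum(dim=2, keepdim=True)

    return beliefs
```

The speedup should come from doing one topk per round instead of n topk calls on concatenated copies; the actual gain depends on n, m and the device, so it is worth timing both versions on real inputs.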