Replacing a for-loop for a leave-one-row-out max operation

Hi, I need a bit of help replacing a for-loop. Here’s the code snippet:

        for _ in range(n_iter):
            # Message passing
            for i in range(n):
                m_beta_alpha_k[:, :, i] = g[:, i, :] - torch.max(
                    torch.cat((m_alpha_beta[:, :i, :], m_alpha_beta[:, (i + 1):, :]), dim=1), dim=1)[0]
                m_alpha_beta_k[:, :, i] = g[:, :, i] - torch.max(
                    torch.cat((m_beta_alpha[:, :i, :], m_beta_alpha[:, (i + 1):, :]), dim=1), dim=1)[0]

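            # Copy rather than alias: without clone(), the *_k buffers and the
            # messages become the same tensor, so the next pass would overwrite
            # values while they are still being read.
            m_alpha_beta = m_alpha_beta_k.clone()
            m_beta_alpha = m_beta_alpha_k.clone()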

Each tensor is 3D with shape [batch, i, j]: every [i, j] slice is a matrix, and the batch dimension stacks the matrices to operate on. As the loop shows, for each i I want to exclude row i of the [i, j] matrix and take the column-wise max over the remaining rows. Is there a way to do this efficiently in torch, without the inner loop over i?
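
One possible way to drop the inner loop, offered as a sketch rather than a definitive answer: the max over all rows except row i equals the overall column max, unless row i is where that max sits, in which case it is the second-largest value in the column. So a single `torch.topk` with `k=2` along dim 1 is enough. The helper names `max_excluding_each_row` and `message_passing_step` are made up for illustration. The sketch assumes square matrices of shape [batch, n, n] with n >= 2, and that both updates read the messages from the previous iteration.

```python
import torch


def max_excluding_each_row(x):
    """For x of shape [batch, n, m], return a tensor of the same shape where
    out[:, i, :] is the column-wise max of x over every row except row i.
    Requires n >= 2. (Hypothetical helper, not a torch built-in.)"""
    # Two largest values per column, plus the row index of each.
    top2, idx = torch.topk(x, k=2, dim=1)             # both [batch, 2, m]
    best, second = top2[:, 0:1, :], top2[:, 1:2, :]   # each [batch, 1, m]
    rows = torch.arange(x.shape[1], device=x.device).view(1, -1, 1)
    # True where row i holds the column max, so excluding it exposes `second`.
    is_argmax = rows == idx[:, 0:1, :]                # [batch, n, m]
    return torch.where(is_argmax, second, best)


def message_passing_step(g, m_alpha_beta, m_beta_alpha):
    """One pass of the inner loop from the question, without the loop over i.
    Assumes g and both messages have shape [batch, n, n]."""
    # Loop version: m_beta_alpha_k[:, :, i] = g[:, i, :] - max over rows != i
    m_beta_alpha_k = (g - max_excluding_each_row(m_alpha_beta)).transpose(1, 2)
    # Loop version: m_alpha_beta_k[:, :, i] = g[:, :, i] - max over rows != i
    m_alpha_beta_k = g - max_excluding_each_row(m_beta_alpha).transpose(1, 2)
    return m_alpha_beta_k, m_beta_alpha_k
```

The outer loop would then become `m_alpha_beta, m_beta_alpha = message_passing_step(g, m_alpha_beta, m_beta_alpha)` repeated `n_iter` times. Ties are harmless here: if the max appears in two rows, the first and second values are equal, so the result is the same whichever row is excluded.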