How to prevent leaf variable from moving into graph interior

Hi, I’m running into an issue when running an algorithm on top of a computational graph. For context, m_alpha_beta and m_beta_alpha are the outputs of a neural network computation. I run an algorithm on these tensors, then compute a loss on the output and backpropagate it to train the network. Training currently fails with “RuntimeError: leaf variable has been moved into the graph interior”, which I suspect is caused by in-place operations. My model returns m_alpha_beta_k and m_beta_alpha_k, and I compute a cross-entropy loss from them.

Is there any way to prevent this issue or to code it better?

def simp_min_sum_batch(self, m_alpha_beta, m_beta_alpha):
        n = m_alpha_beta.size(1)
        # Define message vectors
        # Each row belongs to one of the nodes, we have n alpha nodes and n beta nodes
        # For row i and column j, each entry will be the message vector of alpha_i to beta_j, M_{alpha_i -> beta_j} of length n,
        # where each entry of this vector is m_{alpha_i -> beta_j} (q)
        m_alpha_beta_k = torch.zeros((m_alpha_beta.size(0), m_alpha_beta.size(1), m_alpha_beta.size(2)),
                                     device=m_alpha_beta.device, requires_grad=True)
        m_beta_alpha_k = torch.zeros((m_beta_alpha.size(0), m_beta_alpha.size(1), m_beta_alpha.size(2)),
                                     device=m_beta_alpha.device, requires_grad=True)

        # Message passing
        for i in range(n):
            m_beta_alpha_k[:, :, i] = m_alpha_beta[:, i, :] - torch.max(
                torch.cat((m_alpha_beta[:, :i, :], m_alpha_beta[:, (i + 1):, :]), dim=1), dim=1)[0]
            m_alpha_beta_k[:, :, i] = m_alpha_beta[:, :, i] - torch.max(
                torch.cat((m_beta_alpha[:, :i, :], m_beta_alpha[:, (i + 1):, :]), dim=1), dim=1)[0]

        return m_alpha_beta_k, m_beta_alpha_k

The algorithm is hard to decipher, but in general you have two options: either express it without a loop (so the torch.zeros() buffers don’t need to be created explicitly), or append the partial results to a Python list and call torch.stack afterwards.
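For context on why the explicit buffers matter, here is a minimal sketch of the failure, separate from the original model. The tensor `w` is a hypothetical stand-in for a network output. It also shows that the same in-place write goes through when the buffer is created without requires_grad=True, which is an observation about autograd rather than something suggested earlier in this thread.

```python
import torch

w = torch.ones(3, requires_grad=True)  # hypothetical stand-in for a network output

# Buffer created as a leaf that requires grad: the in-place write is rejected.
# Depending on the PyTorch version, the error surfaces at the assignment or at
# backward(), and its wording differs.
failed = False
try:
    bad = torch.zeros(3, requires_grad=True)
    bad[0] = w[0] * 2
    bad.sum().backward()
except RuntimeError as err:
    failed = True
    print("RuntimeError:", err)

# Same write into a plain buffer: autograd records the in-place copy,
# so the gradient still reaches w.
buf = torch.zeros(3)
buf[0] = w[0] * 2
buf.sum().backward()
print(w.grad)
```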

Hmm, I’m not sure I can remove the for loop, though, since I need to remove a certain row from the m_alpha_beta and m_beta_alpha matrices. Stacking sounds worth a try, though; I could probably just create a tensor and keep stacking onto it?

I see your loop as

for i in range(n):
  ba_k[i] = f1(ab,ba,i)
  ab_k[i] = f2(ab,ba,i)

so it looks parallelizable (non-recurrent) to me… it may be too complicated though…
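One possible loop-free formulation, as a rough sketch only: the "max over every row except i" can be obtained from the two largest values along that dimension. This top-2 trick is a substitution of mine, not necessarily what was meant above. The helper name `leave_one_out_max` is hypothetical, and the code assumes square (batch, n, n) inputs with n >= 2.

```python
import torch

def leave_one_out_max(x):
    # x: (B, n, m). Returns out with out[:, i, :] = max over rows j != i of x[:, j, :].
    top2, idx = x.topk(2, dim=1)                      # values and indices, each (B, 2, m)
    n = x.size(1)
    rows = torch.arange(n, device=x.device).view(1, n, 1)
    is_argmax = rows == idx[:, :1, :]                 # (B, n, m): is row i the arg-max?
    # If row i holds the maximum, fall back to the runner-up; otherwise use the maximum.
    return torch.where(is_argmax, top2[:, 1:2, :], top2[:, 0:1, :])

def simp_min_sum_batch_vectorized(m_alpha_beta, m_beta_alpha):
    # Assumes both inputs have shape (B, n, n).
    loo_ab = leave_one_out_max(m_alpha_beta)
    loo_ba = leave_one_out_max(m_beta_alpha)
    # Matches m_beta_alpha_k[:, :, i] = m_alpha_beta[:, i, :] - max_{j != i} m_alpha_beta[:, j, :]
    m_beta_alpha_k = (m_alpha_beta - loo_ab).transpose(1, 2)
    # Matches m_alpha_beta_k[:, :, i] = m_alpha_beta[:, :, i] - max_{j != i} m_beta_alpha[:, j, :]
    m_alpha_beta_k = m_alpha_beta - loo_ba.transpose(1, 2)
    return m_alpha_beta_k, m_beta_alpha_k
```

No buffer is written in place here, so both outputs stay ordinary non-leaf tensors connected to the inputs.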

You don’t have to create tensors, as the (m_ab - torch.max(…)) expressions already allocate memory for the partial results. You just keep them around in a list, and stack once after the loop.
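A minimal sketch of that list-and-stack approach, assuming both inputs have shape (batch, n, n) with n >= 2. It is written as a free function named `simp_min_sum_batch_stacked` (a hypothetical name) rather than a method, and it keeps the original per-iteration expressions unchanged.

```python
import torch

def simp_min_sum_batch_stacked(m_alpha_beta, m_beta_alpha):
    n = m_alpha_beta.size(1)
    ba_parts, ab_parts = [], []  # partial results, one (B, n) tensor per iteration

    for i in range(n):
        # All rows except row i, as in the original loop.
        others_ab = torch.cat((m_alpha_beta[:, :i, :], m_alpha_beta[:, (i + 1):, :]), dim=1)
        others_ba = torch.cat((m_beta_alpha[:, :i, :], m_beta_alpha[:, (i + 1):, :]), dim=1)
        ba_parts.append(m_alpha_beta[:, i, :] - torch.max(others_ab, dim=1)[0])
        ab_parts.append(m_alpha_beta[:, :, i] - torch.max(others_ba, dim=1)[0])

    # Stacking along dim=2 puts iteration i at [:, :, i], matching the
    # original in-place assignments.
    m_beta_alpha_k = torch.stack(ba_parts, dim=2)
    m_alpha_beta_k = torch.stack(ab_parts, dim=2)
    return m_alpha_beta_k, m_beta_alpha_k
```

Calling backward() on a loss computed from the two returned tensors should then propagate gradients to m_alpha_beta and m_beta_alpha, since nothing is written into a leaf tensor.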