Custom loss function (gradient modified by inplace operation)

Hi,

I am trying to optimize a custom loss function for a Graph Neural Network. Given an n by n adjacency matrix, I first calculate two network statistics: (1) the count of edges and (2) the count of 2-stars. The function below returns the two network statistics.

import torch

def network_stat(A):
  # A: n x n adjacency matrix; returns [edge count, 2-star count]
  n = A.shape[0]
  holder = torch.zeros(2, dtype=torch.float32)
  for i in range(n):
    for j in range(n):
      if i < j:
        holder[0] += A[i,j] # first network statistic (in-place update)
        for k in range(n):
          if j < k:
            holder[1] += A[i,k] * A[j,k] # second network statistic (in-place update)
  return holder
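Written as formulas (a restatement of what the loop computes, with $A$ the $n \times n$ adjacency matrix), the two statistics are:

```latex
s_1(A) = \sum_{i<j} A_{ij}, \qquad
s_2(A) = \sum_{i<j<k} A_{ik} \, A_{jk}
```

Read literally, $s_2$ only counts pairs of edges that meet at node $k$ when $k$ has the largest index of the three nodes. If every 2-star of an undirected graph without self-loops is intended, the usual count is $\sum_k \binom{d_k}{2}$, where $d_k$ is the degree of node $k$.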

My custom loss function for the Graph Neural Network is then:

def loss_function(A_pred, eta):
    # minimization: A_pred obtained from Graph Neural Network
    A_pred = (A_pred>0.5).float() # convert to {0,1} adjacency matrix
    predicted_network_stats = network_stat(A_pred)
    loss = torch.dot(eta, predicted_network_stats) # dot product of two length-2 1-D tensors returns a scalar
    return loss

In other words, I want to train a Graph Neural Network whose predicted graph A_pred minimizes the loss function. However, I received the error below and I am not sure how to fix it:

one of the variables needed for gradient computation has been modified by an inplace operation

My understanding is that my network_stat(A) function breaks the backpropagation of the gradient. Can anyone shed some light on this problem? I greatly appreciate your help.

Yes, I think your description is correct, as network_stat modifies holder in place in:

holder[0] += A[i,j] # first network statistics
# and
holder[1] +=A[i,k] * A[j,k] # second network statistics

You could try replacing these operations with their out-of-place versions:

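holderA = torch.zeros(1)  # start from 1-element tensors so torch.cat works below
holderB = torch.zeros(1)
...
holderA = holderA + A[i, j]
holderB = holderB + (A[i, k] * A[j, k])
...
holder = torch.cat((holderA, holderB), dim=0)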

and check if this would solve the issue.
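As a further option, here is a minimal sketch of the same two statistics computed without Python loops or in-place updates. `network_stat_vectorized` is a hypothetical name, and the sketch assumes `A` is a square 2-D float tensor; it reproduces the loop's index ranges (i < j for edges, i < j < k for 2-stars) exactly.

```python
import torch

def network_stat_vectorized(A: torch.Tensor) -> torch.Tensor:
    """Hypothetical loop-free, out-of-place version of network_stat.

    Computes the same two sums as the loop version:
      edges = sum over i < j of A[i, j]
      stars = sum over i < j < k of A[i, k] * A[j, k]
    """
    # Strict upper triangle: U[i, k] = A[i, k] if i < k, else 0
    U = torch.triu(A, diagonal=1)
    edges = U.sum()
    # M[i, j] = sum_k U[i, k] * U[j, k], i.e. the sum over k > max(i, j)
    M = U @ U.T
    # Keep only pairs with i < j; for those, k > max(i, j) means k > j
    stars = torch.triu(M, diagonal=1).sum()
    # Shape (2,): [edge count, 2-star count]
    return torch.stack((edges, stars))
```

Every step is an ordinary differentiable tensor operation, so gradients with respect to `A` flow through both entries, and the cost no longer grows with the number of Python-level loop iterations.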

Hi,

Your solution solves the problem perfectly! Thank you so much. However, my loss tends to bounce around. I am wondering if I did anything wrong here, as I am relatively new to PyTorch. I greatly appreciate any guidance and assistance!

This is my function to calculate the two graph statistics, which works fine now:

def network_stat(adj):
  n = adj.shape[0]
  # 1-element accumulators, updated out of place so autograd can track them
  holderA = torch.tensor([0], dtype=torch.float32)
  holderB = torch.tensor([0], dtype=torch.float32)
  for i in range(n):
    for j in range(n):
      if i < j:
        holderA = holderA + adj[i,j]  # edge count
        for k in range(n):
          if j < k:  # i < j already holds here
            holderB = holderB + (adj[i,k] * adj[j,k])  # 2-star count
  holder = torch.cat((holderA, holderB), dim=0)
  return holder

This is my function to compute the loss, which depends on the statistics of the predicted graph:

def loss_function(A_pred, eta):

    neg_entropy = torch.log(A_pred).sum()

    A_pred = (A_pred>0.5).float() # make float into 0 and 1 only
    predicted_network_stats = network_stat(A_pred)

    loss = -torch.dot(eta, predicted_network_stats) + neg_entropy
    return loss

This is my code to train the model, which involves a minimization (over the GCN parameters) nested inside a maximization (over eta):

for max_epoch in range(20): 
  for min_epoch in range(50):

    optimizer.zero_grad()
  
    A_pred = model(feature, input) 
    loss = loss_function(A_pred, eta)

    loss.backward()
    optimizer.step()  

  with torch.no_grad():
    A_pred = (A_pred>0.5).float()
    expected_network_stats = network_stat(A_pred)
    eta += lr * (observed_network_stats - expected_network_stats)

Thank you so much!!!

Good to hear it’s working now!
The threshold operation:

A_pred = (A_pred>0.5).float()

will detach the result from the computation graph, so A_pred and predicted_network_stats will not contribute to the gradient calculation and only add a constant offset to the loss.
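A small sketch illustrating this, with a random `logits` tensor standing in (as an assumption) for the GNN output. It also shows one common workaround that is not part of this thread: a straight-through estimator, which keeps the hard 0/1 values in the forward pass while passing the gradient with respect to `A_prob` through unchanged in the backward pass. Whether that biased gradient suits this particular loss is a modeling question.

```python
import torch

torch.manual_seed(0)

# Assumption: a random logits tensor stands in for the GNN output
logits = torch.randn(4, 4, requires_grad=True)
A_prob = torch.sigmoid(logits)  # edge probabilities in (0, 1)

# Hard threshold as in loss_function: the comparison yields a bool tensor
# that autograd does not track, so .float() of it has no grad_fn either
A_hard = (A_prob > 0.5).float()
print(A_hard.requires_grad, A_hard.grad_fn)  # False None

# Straight-through estimator (swapped-in technique, not from this thread):
# the forward value is exactly A_hard because A_prob - A_prob.detach() is zero,
# but the backward pass sees the identity with respect to A_prob
A_st = A_hard + (A_prob - A_prob.detach())
print(A_st.grad_fn is not None)  # True

A_st.sum().backward()
# The gradient reaches logits through the sigmoid: p * (1 - p)
print(logits.grad)
```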

Thank you very much for pointing out the issue! I have changed (A_pred>0.5).float() to torch.bernoulli(A_pred), which I think will make predicted_network_stats contribute to the gradient calculation.

By the way, since my function network_stat(adj) uses multiple for-loops iterating over the (i, j) entries of the adjacency matrix, I am wondering whether its output will also contribute to the gradient calculation. Or do the for-loops also interrupt the computation graph?

Thank you very much again!

Loops do not detach the computation graph. You can check whether the output tensor has a valid .grad_fn attribute (i.e. not None), which indicates that it's attached to a computation graph.
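A minimal sketch of that check, using a hypothetical helper `edge_count_loop` that accumulates only the edge count with the same out-of-place pattern as network_stat. It also contrasts a thresholded input, where the missing grad_fn comes from the threshold rather than from the loop.

```python
import torch

def edge_count_loop(A: torch.Tensor) -> torch.Tensor:
    # Hypothetical helper: same out-of-place accumulation pattern as network_stat
    n = A.shape[0]
    total = torch.zeros(())
    for i in range(n):
        for j in range(i + 1, n):
            total = total + A[i, j]  # each '+' is recorded by autograd
    return total

A = torch.rand(4, 4, requires_grad=True)

out = edge_count_loop(A)
print(out.grad_fn)  # an AddBackward0 node: the loop output is attached to the graph

# Thresholding the input (not the loop) is what detaches the result
print(edge_count_loop((A > 0.5).float()).grad_fn)  # None

out.backward()
# d(out)/dA[i, j] is 1 for i < j and 0 elsewhere
print(A.grad)
```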