Hi,
I am trying to optimize a custom loss function for a Graph Neural Network. Given an n-by-n adjacency matrix, I first compute two network statistics: (1) the edge count and (2) the 2-star count. The function below returns these two statistics:
```python
def network_stat(A):
    n = A.shape[0]
    holder = torch.zeros(2, dtype=torch.float32)
    for i in range(n):
        for j in range(n):
            if i < j:
                holder[0] += A[i, j]  # first network statistic: edge count
                for k in range(n):
                    if j < k:
                        holder[1] += A[i, k] * A[j, k]  # second network statistic: 2-star count
    return holder
```
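As an aside, the nested loops scale as O(n³); the same two statistics can be written with tensor operations (a sketch only; the name `network_stat_vec` is mine, and it follows the same i < j < k convention as the loops above):

```python
import torch

def network_stat_vec(A):
    # strict upper triangle keeps exactly the entries A[i, k] with i < k
    U = torch.triu(A, diagonal=1)
    edges = U.sum()               # sum over i < j of A[i, j]
    s = U.sum(dim=0)              # s[k]  = sum over i < k of A[i, k]
    sq = (U ** 2).sum(dim=0)      # sq[k] = sum over i < k of A[i, k] ** 2
    # for each k: sum over i < j < k of A[i, k] * A[j, k] = (s[k]**2 - sq[k]) / 2
    two_stars = ((s ** 2 - sq) / 2).sum()
    return torch.stack((edges, two_stars))
```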
Then my custom loss function for the Graph Neural Network is:

```python
def loss_function(A_pred, eta):
    # minimization: A_pred obtained from the Graph Neural Network
    A_pred = (A_pred > 0.5).float()  # convert to a {0, 1} adjacency matrix
    predicted_network_stats = network_stat(A_pred)
    loss = torch.dot(eta, predicted_network_stats)  # dot of two length-2 vectors returns a scalar
    return loss
```
In other words, I want to learn a Graph Neural Network whose predicted graph A_pred minimizes the loss function. However, I received the error below and am not sure how to fix it:
```
one of the variables needed for gradient computation has been modified by an inplace operation
```
My understanding is that my network_stat(A) function has broken the back-propagation of gradients. Can anyone shed some insight on this problem? I greatly appreciate your help.
Yes, I think your description is correct, as network_stat modifies holder in-place in:

```python
holder[0] += A[i, j]  # first network statistic
# and
holder[1] += A[i, k] * A[j, k]  # second network statistic
```
You could try to replace these operations with their out-of-place versions via:

```python
holderA = holderA + A[i, j]
holderB = holderB + (A[i, k] * A[j, k])
...
holder = torch.cat((holderA, holderB), dim=0)
```

and check if this solves the issue.
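For reference, a minimal sketch (with a hypothetical random 3×3 input) of the full out-of-place version, showing that backward now runs and the output stays attached to the graph:

```python
import torch

def network_stat(A):
    # out-of-place accumulation: each `+` builds a new tensor that stays
    # attached to the autograd graph instead of mutating one in place
    n = A.shape[0]
    holderA = torch.zeros(1)
    holderB = torch.zeros(1)
    for i in range(n):
        for j in range(n):
            if i < j:
                holderA = holderA + A[i, j]
                for k in range(n):
                    if j < k:
                        holderB = holderB + A[i, k] * A[j, k]
    return torch.cat((holderA, holderB), dim=0)

A = torch.rand(3, 3, requires_grad=True)
stats = network_stat(A)
stats.sum().backward()  # runs without the inplace-operation error
```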
Hi,
Your solution solves the problem perfectly! Thank you so much again. However, my loss tends to bounce around. I am wondering if I did anything wrong here, as I am relatively new to PyTorch. I greatly appreciate any guidance and assistance!
This is my function to calculate the two graph statistics, which works fine now:
```python
def network_stat(adj):
    n = adj.shape[0]
    holderA = torch.tensor([0], dtype=torch.float32)
    holderB = torch.tensor([0], dtype=torch.float32)
    for i in range(n):
        for j in range(n):
            if i < j:
                holderA = holderA + adj[i, j]
            for k in range(n):
                if i < j and j < k:
                    holderB = holderB + (adj[i, k] * adj[j, k])
    holder = torch.cat((holderA, holderB), dim=0)
    return holder
```
This is my function to compute the loss, which depends on the statistics of the predicted graph:
```python
def loss_function(A_pred, eta):
    neg_entropy = torch.log(A_pred).sum()
    A_pred = (A_pred > 0.5).float()  # threshold probabilities to {0, 1}
    predicted_network_stats = network_stat(A_pred)
    loss = -torch.dot(eta, predicted_network_stats) + neg_entropy
    return loss
```
This is my code to train the model, which involves a minimization (over the GCN) nested within a maximization (over eta):
```python
for max_epoch in range(20):
    for min_epoch in range(50):
        optimizer.zero_grad()
        A_pred = model(feature, input)
        loss = loss_function(A_pred, eta)
        loss.backward()
        optimizer.step()
    with torch.no_grad():
        A_pred = (A_pred > 0.5).float()
        expected_network_stats = network_stat(A_pred)
        eta += lr * (observed_network_stats - expected_network_stats)
```
Thank you so much!!!
Good to hear it’s working now!
The threshold operation:

```python
A_pred = (A_pred > 0.5).float()
```

will detach the result from the computation graph, so A_pred and predicted_network_stats will not contribute to the gradient calculation and are just a constant offset to loss.
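A quick toy check of this (values are arbitrary; `p` stands in for A_pred):

```python
import torch

p = torch.rand(4, requires_grad=True)  # stand-in for A_pred
hard = (p > 0.5).float()

# comparison ops are non-differentiable, so the result is detached:
print(hard.requires_grad)  # False
print(hard.grad_fn)        # None
```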
Thank you very much for pointing out the issue! I have changed (A_pred > 0.5).float() to torch.bernoulli(A_pred), which I think will let predicted_network_stats contribute to the gradient calculation.

By the way, since my network_stat(adj) function involves multiple for-loops iterating over the i, j entries of the adjacency matrix, I am wondering whether the output of this function will also contribute to the gradient calculation, or whether the for-loops interrupt the computation graph for the gradient?

Thank you very much again!
Loops do not detach the computation graph. You could check whether the output tensor has a valid .grad_fn attribute, which indicates that it is attached to a computation graph.
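For example (toy tensor; the same check applies to the output of network_stat):

```python
import torch

x = torch.rand(5, requires_grad=True)

acc = torch.zeros(1)
for i in range(5):
    acc = acc + x[i]   # out-of-place add inside a plain Python loop

print(acc.grad_fn is not None)  # True: the loop kept the graph attached

hard = (x > 0.5).float()
print(hard.grad_fn)             # None: thresholding detaches
```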