RuntimeError: one of the variables needed for gradient computation has been modified by an inplace operation: [torch.cuda.FloatTensor [768, 64]], which is output 0 of AsStridedBackward0, is at version 3; expected version 2 instead

Hi, I need your help in debugging the error.
As you can see in the title, I countered the RuntimeError during the backward process.
By following the suggestions such as .clone(), inplace=False in leaky_relu(), it always returned the same error.
Please see my code and let me know how should I fix it.

class MyModel(torch.nn.Module):
    def __init__(self, in_channels, hidden_channels, out_channels):
        super(MyModel, self).__init__()
        
        # MLP
        self.converting_layer = nn.Sequential(
            nn.Linear(4, 64, bias=True),
            nn.LeakyReLU(), 
            nn.Linear(64, 256, bias=True)
        )
        
        # GAT
        self.conv1 = GATConv(in_channels, hidden_channels, heads=6) # in_channel, hidden_channel, heads
        self.lin1 = torch.nn.Linear(in_channels, hidden_channels * 6)
        self.norm1 = LayerNorm(hidden_channels * 6)
        self.conv2 = GATConv(hidden_channels * 6, out_channels, heads=1, concat=False) # hidden_channel * heads, out_channel, heads
        self.lin2 = torch.nn.Linear(hidden_channels * 6, out_channels)
        
        self._reset_parameters()

    def convert_data_size(self, data1, data2):
        x2 = torch.zeros((len(data1), 256), dtype=torch.float).to(device) #(107940, 256)
        x3 = self.converting_layer(data2).to(device) #(19392, 256)
        nonzero_index = torch.tensor(cID).to(device)
        x2.index_add_(0, nonzero_index, x3) #(107940, 256)
        
        return x2
    
    def _reset_parameters(self):
        for p in self.parameters():
            if p.dim() > 1:
                nn.init.xavier_uniform_(p)

    
    def forward(self, data1, data2, neighbor_cl_ids, edge_index):
        x_dict, attn1_dict, attn2_dict = {}, {}, {}
        
        for i in range(len(data2)):
            data2[i] = self.convert_data_size(data1, data2[i]) #(107940, 256), requires_grad=true
            
        for e_i_class in torch.unique(neighbor_cl_ids):
            x = torch.cat((data1, data2[int(e_i_class.item())]), dim=1) #(107940, 512)
            x_1, a1 = self.conv1(x, edge_index, return_attention_weights=True) 
            x = F.leaky_relu(self.norm1(x_1))  #
            x = F.dropout(x, p = 0.2, training = self.training)
            x_2, a2 = self.conv2(x, edge_index, return_attention_weights=True)
            x = x_2 + self.lin2(x) #Skip connection
            
            x_dict[int(e_i_class.item())] = x
            attn1_dict[int(e_i_class.item())] = a1
            attn2_dict[int(e_i_class.item())] = a2
            
        return x_dict, attn1_dict, attn2_dict
def train(num_epochs):
    for epoch in range(1, num_epochs+1):
        tr_loss = []
        val_loss = []
        tr_losses = 0
        val_losses = 0
        tr_loss_sum = 0
        val_loss_sum = 0
        
        model.train()
        with torch.autograd.detect_anomaly():
            for data in tqdm(train_loader):
                
                data = data.to(device)
                data.edge_class = data.edge_class[data.input_id]
                data.edge_index_class = torch.zeros(len(data.edge_index[0])).to(device)    
                        
                for i in range(len(data.edge_index_class)):
                    if ((data.edge_index[1][i] in data.edge_label_index[0]) or (data.edge_index[1][i] in data.edge_label_index[1])):
                        data.edge_index_class[i] = data.edge_class[(data.edge_index[1][i]==data.edge_label_index[0])
                                                                     |(data.edge_index[1][i]==data.edge_label_index[1])]
                    else:
                        data.edge_index_class[i] = torch.max(data.edge_index_class[(data.edge_index[1][i]==data.edge_index[0])])
                        
                optimizer.zero_grad()
                z, a1, a2 = model(data.x[0], data.x[1], data.edge_index_class, data.edge_index) 
                
                for i in torch.unique(data.edge_class):
                    i = int(i.item())
                    tr_out = ((z[i][data.edge_label_index[0][data.edge_class==i]] * z[i][data.edge_label_index[1][data.edge_class==i]]).sum(dim=-1)).view(-1)
                    tr_loss.append(criterion(tr_out, data.edge_label[data.edge_class==i].float()))
                tr_loss_sum = sum(tr_loss) / len(tr_loss) #average loss
                tr_loss_sum.backward(retain_graph=True)
                optimizer.step()
                
                tr_losses = tr_losses + tr_loss_sum.item()
            avg_tr_loss = tr_losses/len(train_loader.dataset)

        model.eval()
        with torch.no_grad():
            y_val_pred, y_val_pred_prob, y_val_true = [], [], []
            for data in tqdm(val_loader):
                data = data.to(device)
                data.edge_class = data.edge_class[data.input_id]
                data.edge_index_class = torch.zeros(len(data.edge_index[0])).to(device)
            
                for i in range(len(data.edge_index_class)):
                    if ((data.edge_index[1][i] in data.edge_label_index[0]) or (data.edge_index[1][i] in data.edge_label_index[1])):
                        data.edge_index_class[i] = data.edge_class[(data.edge_index[1][i]==data.edge_label_index[0])
                                                                     |(data.edge_index[1][i]==data.edge_label_index[1])]
                    else:
                        data.edge_index_class[i] = torch.max(data.edge_index_class[(data.edge_index[1][i]==data.edge_index[0])])
                    
                y_val_true.append(data.edge_label)
                z, a1, a2 = model(data.x[0], data.x[1], data.edge_index_class, data.edge_index)
                
                for i in torch.unique(data.edge_class):
                    i = int(i.item())
                    val_out = ((z[i][data.edge_label_index[0][data.edge_class==i]] * z[i][data.edge_label_index[1][data.edge_class==i]]).sum(dim=-1)).view(-1)
                    val_out_sig = ((z[i][data.edge_label_index[0][data.edge_class==i]] * z[i][data.edge_label_index[1][data.edge_class==i]]).sum(dim=-1)).view(-1).sigmoid()
                    val_loss.append(criterion(val_out, data.edge_label[data.edge_class==i].float()))
                    
                    y_val_pred.append((out_sig>0.5).float().cpu())
                    y_val_pred_prob.append((out_sig).float().cpu())
                    
                y_val_pred_combined = torch.cat(y_val_pred, dim=0)
                y_val_pred_prob_combined = torch.cat(y_val_pred_prob, dim=0)
                val_loss_sum = sum(val_loss) / len(val_loss)
                val_losses = val_losses + val_loss_sum.item()

        avg_val_loss = val_losses/len(val_loader.dataset)
        y, pred, pred_prob = torch.cat(y_val_true, dim=0).cpu().numpy(), torch.cat(y_val_pred_combined, dim=0).cpu().numpy(), torch.cat(y_val_pred_prob_combined, dim=0).cpu().numpy()
        val_f1 = f1_score(y, pred) #average='micro'
        val_auc = roc_auc_score(y, pred_prob)
        val_aupr = average_precision_score(y, pred_prob)
        val_acc = accuracy_score(y, pred)
        
        print(f'Epoch: {epoch:03d}, Training Loss: {avg_tr_loss:.4f}, Validation Loss: {avg_val_loss:.4f} \n Validation AUC: {val_auc:.4f}, Validation AUPR: {val_aupr:.4f}, Validation F1-score: {val_f1:.4f}')

RuntimeError: one of the variables needed for gradient computation has been modified by an inplace operation: [torch.cuda.FloatTensor [768, 64]], which is output 0 of AsStridedBackward0, is at version 3; expected version 2 instead. Hint: the backtrace further above shows the operation that failed to compute its gradient. The variable in question was changed in there or anywhere later. Good luck!

Note
This is part of the error message when I set with torch.autograd.detect_anomaly():.
I hope this message will be helpful to you.

File “/tmp/ipykernel_60238/3515189033.py”, line 28, in train
z, a1, a2 = model(data.x[0], data.x[1], data.edge_index_class, data.edge_index)
File “/scratch/r902a02/.conda/envs/PyG_scratch1/lib/python3.9/site-packages/torch/nn/modules/module.py”, line 1518, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File “/scratch/r902a02/.conda/envs/PyG_scratch1/lib/python3.9/site-packages/torch/nn/modules/module.py”, line 1527, in _call_impl
return forward_call(*args, **kwargs)
File “/tmp/ipykernel_60238/2130339337.py”, line 52, in forward
x = x_2 + self.lin2(x) #error alert
File “/scratch/r902a02/.conda/envs/PyG_scratch1/lib/python3.9/site-packages/torch/nn/modules/module.py”, line 1518, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File “/scratch/r902a02/.conda/envs/PyG_scratch1/lib/python3.9/site-packages/torch/nn/modules/module.py”, line 1527, in _call_impl
return forward_call(*args, **kwargs)
File “/scratch/r902a02/.conda/envs/PyG_scratch1/lib/python3.9/site-packages/torch/nn/modules/linear.py”, line 114, in forward
return F.linear(input, self.weight, self.bias)
(Triggered internally at /opt/conda/conda-bld/pytorch_1695392022560/work/torch/csrc/autograd/python_anomaly_mode.cpp:114.)
Variable._execution_engine.run_backward( # Calls into the C++ engine to run the backward pass

Using retain_graph=True in a backward call can easily cause these issues.
Could you explain why this argument is used?

1 Like

Hi, thank you for your reply!
I added that option because there was another error when I didn’t use it.

RuntimeError: Trying to backward through the graph a second time (or directly access saved tensors after they have already been freed). Saved intermediate values of the graph are freed when you call .backward() or autograd.grad(). Specify retain_graph=True if you need to backward through the graph a second time or if you need to access saved tensors after calling backward.

I guess this error was caused by the backbone network in my model.
As you can see, my model gets edge_index for the message passing and two kinds of node features(i.e., data1, data2) for the GATConv layer for the different classes of edges.

def forward(self, data1, data2, neighbor_cl_ids, edge_index):
        x_dict, attn1_dict, attn2_dict = {}, {}, {}
        ...
        for e_i_class in torch.unique(neighbor_cl_ids):
            x = torch.cat((data1, data2[int(e_i_class.item())]), dim=1) #(107940, 512)
            x_1, a1 = self.conv1(x, edge_index, return_attention_weights=True) 
            x = F.leaky_relu(self.norm1(x_1))  
            x = F.dropout(x, p = 0.2, training = self.training)
            x_2, a2 = self.conv2(x, edge_index, return_attention_weights=True)
            x = x_2 + self.lin2(x) #Skip connection

            # Save the independent node embedding vectors per e_i_class
            x_dict[int(e_i_class.item())] = x
            attn1_dict[int(e_i_class.item())] = a1
            attn2_dict[int(e_i_class.item())] = a2
            
        return x_dict, attn1_dict, attn2_dict

Here is the example result of the x. I want to update each node embedding per e_i_class which can be more than 2.

#{0: tensor([[ 0.0909, -3.4048, -5.4494,  ..., -2.0168,  0.2386,  2.0413],
#         [-0.0409, -1.8523, -2.5657,  ...,  1.1624, -0.4525,  0.3905],
#         [-0.8994, -2.6931, -1.8342,  ...,  1.1342, -0.0207,  1.3258],
#         ...,
#         [-0.1011, -0.6471, -2.7655,  ..., -0.4696,  2.2361,  0.7389],
#         [ 0.5757, -0.7008, -3.3981,  ...,  0.8195,  0.4786,  0.6613],
#         [ 0.5942, -0.6595, -2.6022,  ...,  0.6340,  1.9474,  0.6223]],
#        device='cuda:0', grad_fn=<AddBackward0>),
# 1: tensor([[ 0.3415, -4.1216, -3.2260,  ...,  0.2468,  0.1036,  0.5828],
#         [ 1.3558, -2.1061, -2.2542,  ...,  1.1631,  0.8181,  2.1384],
#         [-0.4299, -2.2539, -2.2097,  ...,  1.1434,  0.1344,  0.5396],
#         ...,
#         [ 0.0113, -1.2283, -1.9238,  ..., -0.5057,  0.6364, -0.4368],
#         [ 0.2154,  0.2887, -2.2543,  ...,  1.2211,  1.2807,  1.7913],
#         [-0.3744, -0.5113, -2.3580,  ..., -0.5218,  2.4996,  1.4214]],
#        device='cuda:0', grad_fn=<AddBackward0>)}

Also, I suspect the loss calculation. Is it correct to calculate the combined loss optimization?

optimizer.zero_grad()
z, a1, a2 = model(data.x[0], data.x[1], data.edge_index_class, data.edge_index)
                
for i in torch.unique(data.edge_class):
     i = int(i.item())
     tr_out = ((z[i][data.edge_label_index[0][data.edge_class==i]] * z[i][data.edge_label_index[1][data.edge_class==i]]).sum(dim=-1)).view(-1)
     tr_loss.append(criterion(tr_out, data.edge_label[data.edge_class==i].float()))
tr_loss_sum = sum(tr_loss) / len(tr_loss) #average
tr_loss_sum.backward()
optimizer.step()

If my assumption is wrong, please correct it.
You see the detailed error message here.


File “/tmp/ipykernel_1955/2643827018.py”, line 50, in
train(num_epochs)
File “/tmp/ipykernel_1955/3582490242.py”, line 33, in train
tr_loss.append(criterion(tr_out, data.edge_label[data.edge_class==i].float()))
File “/scratch/r902a02/.conda/envs/PyG_scratch1/lib/python3.9/site-packages/torch/nn/modules/module.py”, line 1518, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File “/scratch/r902a02/.conda/envs/PyG_scratch1/lib/python3.9/site-packages/torch/nn/modules/module.py”, line 1527, in _call_impl
return forward_call(*args, **kwargs)
File “/scratch/r902a02/.conda/envs/PyG_scratch1/lib/python3.9/site-packages/torch/nn/modules/loss.py”, line 725, in forward
return F.binary_cross_entropy_with_logits(input, target,
File “/scratch/r902a02/.conda/envs/PyG_scratch1/lib/python3.9/site-packages/torch/nn/functional.py”, line 3195, in binary_cross_entropy_with_logits
return torch.binary_cross_entropy_with_logits(input, target, weight, pos_weight, reduction_enum)
(Triggered internally at /opt/conda/conda-bld/pytorch_1695392022560/work/torch/csrc/autograd/python_anomaly_mode.cpp:114.)
Variable._execution_engine.run_backward( # Calls into the C++ engine to run the backward pass


Update
When I set one input dataset and ran the below code, it worked.

for data in tqdm(train_loader):
...
    optimizer.zero_grad()
    z, a1, a2 = model(data.x[0], data.x[1], data.edge_index_class, data.edge_index)
                
    for i in torch.unique(data.edge_class):
        i = int(i.item())
        tr_out = ((z[i][data.edge_label_index[0][data.edge_class==i]] * z[i][data.edge_label_index[1][data.edge_class==i]]).sum(dim=-1)).view(-1)
        tr_loss=criterion(tr_out, data.edge_label[data.edge_class==i].float())
                #tr_loss_sum = sum(tr_loss) / len(tr_loss)
    tr_loss.backward()
    optimizer.step()
                
    tr_losses += tr_loss.item()
avg_tr_loss = tr_losses/len(train_loader.dataset)

How can I deal with the problem when I use more than two datasets during loss calculation? I want to optimize the multiple losses and update the node embeddings per edge_class.

One of the differences is that your first approach appends the loss to tr_loss and I don’t see where this list is cleared or recreated. Appending the losses will try to backpropagate through all iterations and would cause the issue. Recreate tr_loss after the backward pass or check if you really want to keep the old computation graphs alive. In this case, you would need to delay the parameter update step.

1 Like

Thank you for your kind reply and sorry for the late response.
I solved the problem by initializing the loss-related variables as below:

...
optimizer.zero_grad()
z, a1, a2 = model(data.x[0], data.x[1], data.edge_index_class, data.edge_index)
for i in torch.unique(data.edge_class):
    tr_loss = []
    tr_loss_mean = 0
    i = int(i.item())
    
    tr_out = ((z[i][data.edge_label_index[0][data.edge_class==i]] * z[i][data.edge_label_index[1][data.edge_class==i]]).sum(dim=-1)).view(-1)
    tr_loss.append(criterion(tr_out, data.edge_label[data.edge_class==i].float()))
tr_loss_mean = sum(tr_loss)/len(tr_loss) 
tr_loss_mean.backward()
optimizer.step()
tr_losses += tr_loss_mean.item()
...

Since I’m not sure if it is correct, please let me know if my approach is wrong.
Once again, thank you for your solution and have a nice day!

In your current approach you are recreating the tr_loss in every iteration, thus deleting the previously stored tensors. In this case you wouldn’t need to store the loss in the first place and could just return the last one.
I guess you want to store the losses of all edge_classes, so reset the list before the for loop starts.

1 Like

Thank you so much for your advice!
Actually, I was encountering a problem where the model seemed to not learn at all.
After moving it to the front of the for loop, it looks well.
Nevertheless, the model still shows low evaluation metrics, I’ll check other suspicious parts.
Thank you for your kind feedback, again!