A Problem About Autograd and Backward: RuntimeError: one of the variables needed for gradient computation has been modified by an inplace operation

I met this error while calling backward() in my IRM implementation.
RuntimeError: one of the variables needed for gradient computation has been modified by an inplace operation: [torch.FloatTensor [500, 1]], which is output 0 of AsStridedBackward0, is at version 2; expected version 1 instead. Hint: the backtrace further above shows the operation that failed to compute its gradient. The variable in question was changed in there or anywhere later. Good luck!
I’ve searched for it on Google, but I only found solutions about in-place operations like ‘+=’. Adding retain_graph=True also didn’t work.
Here’s the code:

# -*- coding: utf-8 -*-


import sys

import numpy as np
from PIL import Image
import operator as op
import matplotlib.pyplot as plt
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.autograd import grad
from torchvision import transforms
from torchvision import datasets
import torchvision.datasets.utils as dataset_utils

torch.autograd.set_detect_anomaly(True)

def calculate_penalty(loss, dummy_w):
    # IRM penalty: gradients of the per-sample loss w.r.t. the dummy classifier,
    # computed on the even- and odd-indexed halves of the batch and multiplied
    g1 = grad(loss[0::2].mean(), dummy_w, create_graph=True, retain_graph=True)[0]
    g2 = grad(loss[1::2].mean(), dummy_w, create_graph=True, retain_graph=True)[0]
    return (g1 * g2).sum()


class IRM(nn.Module):
    def __init__(self):
        super(IRM,self).__init__()#28*28*3
        self.conv1=nn.Conv2d(3,40,5)#24*24*40
        self.pool1=nn.MaxPool2d(2)#12*12*40
        self.conv2=nn.Conv2d(40,70,5)#8*8*70
        self.pool2=nn.MaxPool2d(2)#4*4*70
        self.fc1=nn.Linear(4*4*70,500)
        self.fc2=nn.Linear(500,1)
    def forward(self,x):
        x1=F.relu(self.conv1(x))
        x2=self.pool1(x1)
        x3=F.relu(self.conv2(x2))
        x4=self.pool2(x3)
        x5=x4.view(-1,self.num_flat_features(x4))
        x6=F.relu(self.fc1(x5))
        x7=F.sigmoid(self.fc2(x6))
        x8=x7.flatten()
        return x8
    def num_flat_features(self,x):
        size=x.size()[1:]	
        num_features=1
        for s in size:
            num_features*=s
        return num_features

def calculate_acc(model,x,y):
    total=len(x)
    right=0
    y_get=model.forward(x)
    for i in range(total):
        if (y_get[i]<=0.5 and y[i]==0) or (y_get[i]>0.5 and y[i]==1):
            right+=1
    print(right, '/', total, ' accuracy: ', right/total, sep='')
    return right/total
def train(model,optimizer,epoches,datafortrain,datafortest):
#data import begin
    x0=[]
    y0=[]
    for i in datafortrain[0]:
        x0+=i
    for i in datafortrain[1]:
        y0+=i
    for i in range(len(x0)):
        x0[i]=torch.from_numpy(np.array(x0[i])).permute(2, 0, 1).float()/255.0
        y0[i]=torch.tensor(float(y0[i]))
    x=torch.stack(x0)
    y=torch.stack(y0)

    dataready=torch.utils.data.DataLoader(dataset=torch.utils.data.TensorDataset(x, y), batch_size=30, shuffle=True)
#data import end
    dummy_w=torch.nn.Parameter(torch.Tensor([1.0]))##
    for epoch in range(epoches):
        penalty_weight=1.0
        error=0.
        penalty=0.
        for step,(x,y) in enumerate(dataready):
            optimizer.zero_grad()
            
            y_get=model.forward(x)
            
            #print(y_get)
            #print(y)
            loss = F.binary_cross_entropy_with_logits(y_get*dummy_w,y,reduction='none')
            #print(loss)
            #sys.exit()
            
            penalty=penalty+calculate_penalty(loss,dummy_w)
            
            error=error+loss.mean()
            
            (error+penalty_weight*penalty).backward(retain_graph=True)
            optimizer.step()
            
        a=calculate_acc(model,datafortrain[0][0],datafortrain[1][0])
        b=calculate_acc(model,datafortrain[0][1],datafortrain[1][1])
        c=calculate_acc(model,datafortest[0][0],datafortest[1][0])
        if a>0.8 and b>0.8 and c>0.5:
            print('done!')
            return
    print('\n')
def entrance():  # this function is just about data import, not shown here
    ...
entrance()

What confuses me is that if I replace penalty=penalty+calculate_penalty(loss,dummy_w) and error=error+loss.mean()
with penalty=calculate_penalty(loss,dummy_w) and error=loss.mean(),
the program runs successfully.
However, in the code from the original article (https://arxiv.org/pdf/1907.02893.pdf), this kind of accumulation with ‘+=’ doesn’t cause any problems.

Hi Whale!

This use of retain_graph = True:

            (error+penalty_weight*penalty).backward(retain_graph=True)

is likely the immediate cause of your problem.

The accumulation

            penalty=penalty+calculate_penalty(loss,dummy_w)

occurs within the loop over batches within an epoch, so penalty depends not
only on the current batch, but on previous batches as well.

When you call optimizer.step(), you modify the parameters of model
inplace. When you call .backward() the next time through the loop, you
.backward() not only through the graph that was just created, but through
the previous graph as well, because you explicitly retained it and because
penalty, having been accumulated over batches, depends on that previous
graph. But the previous graph depends on parameters of model that have
been modified (by optimizer.step()) after that graph was created.
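
For illustration, here is a minimal, self-contained sketch (not your IRM code, just the same pattern) that fails in the same way:

import torch

# A tiny two-layer model; the second layer's weight gets saved in the graph
# because its input (the first layer's output) requires grad.
net = torch.nn.Sequential(torch.nn.Linear(3, 4), torch.nn.Linear(4, 1))
opt = torch.optim.SGD(net.parameters(), lr=0.1)

total = 0.0
for _ in range(2):
    x = torch.randn(5, 3)
    total = total + net(x).pow(2).mean()  # keeps the graph of every past iteration alive
    opt.zero_grad()
    total.backward(retain_graph=True)     # second pass re-traverses the first graph ...
    opt.step()                            # ... whose saved weights were updated in place

On the second iteration this raises the same “modified by an inplace operation” error, because the retained first-iteration graph saved weights that opt.step() has since changed.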

You have to think about the logic of your program. When you call
.backward(), would it be enough to backward through just the current
result of calculate_penalty() or do you really need to backward (again)
through all of the previous batches that have been accumulated into
penalty?
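
If backpropagating through just the current batch is what you want, one way to rework the inner loop (a sketch based on your posted code, not on the paper’s official release) is to backward the per-batch quantities and keep only detached running totals:

error = 0.0
penalty = 0.0
for step, (x, y) in enumerate(dataready):
    optimizer.zero_grad()
    y_get = model(x)
    loss = F.binary_cross_entropy_with_logits(y_get * dummy_w, y, reduction='none')

    batch_error = loss.mean()
    batch_penalty = calculate_penalty(loss, dummy_w)

    # backward only through the current batch's graph; no retain_graph needed
    (batch_error + penalty_weight * batch_penalty).backward()
    optimizer.step()

    # running totals for logging only, detached so no old graph stays alive
    error += batch_error.item()
    penalty += batch_penalty.item()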

Your replacement (penalty=calculate_penalty(loss,dummy_w) and
error=loss.mean() instead of the accumulating versions) seems to make
sense. In this version, penalty no longer contains contributions from
previous iterations of the loop over batches, so you are no longer trying
to backpropagate through graphs that depend on model parameters that have
since been modified.

However, whether this logic that “works” is actually appropriate for your
use case is something you have to decide.

As an aside:

        x7=F.sigmoid(self.fc2(x6))

You pass the result of the final Linear layer of your model through
sigmoid(). This converts the logits produced by Linear into
probabilities. But you then use binary_cross_entropy_with_logits(),
which expects logits rather than probabilities. You probably want to get
rid of the sigmoid().
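
For example, forward() could return the raw output of fc2 (a sketch using the layers you already have), and the sigmoid would only be applied where you threshold for accuracy:

    def forward(self, x):
        x = F.relu(self.conv1(x))
        x = self.pool1(x)
        x = F.relu(self.conv2(x))
        x = self.pool2(x)
        x = x.view(-1, self.num_flat_features(x))
        x = F.relu(self.fc1(x))
        x = self.fc2(x)   # raw logits, which is what binary_cross_entropy_with_logits expects
        return x.flatten()

In calculate_acc() you would then compare torch.sigmoid(y_get[i]) against 0.5, which is equivalent to comparing the logit y_get[i] against 0.0.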

Best.

K. Frank