When optimizing the same network several times, I get "RuntimeError: one of the variables needed for gradient computation has been modified by an inplace operation"

I get this error when running the code below. (The code does nothing useful; it only reproduces the error.) The modules used to define the model come from GitHub.

If the input of the discriminator is detached, the error disappears. The error may come from updating the encoder parameters more than once in the same iteration.

Can I update the encoder several times per iteration? If not, should I freeze the parameters of the UNet while optimizing the discriminators?

import torch
import torch.nn as nn
import torch.optim as optim
# DoubleConv, Down, Up, OutConv are the UNet building blocks from the GitHub
# implementation mentioned above (assumed to be importable from unet_parts here)
from unet_parts import DoubleConv, Down, Up, OutConv

# define model
class discriminator(nn.Module):
    def __init__(self, in_ch=1024) -> None:
        super().__init__()
        self.inc = (DoubleConv(in_ch, 64))
        self.e1 = Down(64, 128)
        self.out = nn.Conv2d(128, 1, 3, 2, 1)

    def forward(self, x):
        x = self.inc(x)
        x = self.e1(x)
        x = self.out(x)
        return x

class UNet(nn.Module):
    def __init__(self, n_channels, n_classes=2, bilinear=False):
        super(UNet, self).__init__()
        self.n_channels = n_channels
        self.n_classes = n_classes
        self.bilinear = bilinear

        self.inc = (DoubleConv(n_channels, 64))
        self.down1 = (Down(64, 128))
        self.down2 = (Down(128, 256))
        self.down3 = (Down(256, 512))
        factor = 2 if bilinear else 1
        self.down4 = (Down(512, 1024 // factor))
        self.up1 = (Up(1024, 512 // factor, bilinear))
        self.up2 = (Up(512, 256 // factor, bilinear))
        self.up3 = (Up(256, 128 // factor, bilinear))
        self.up4 = (Up(128, 64, bilinear))
        self.outc = (OutConv(64, n_classes))
        self.encoder = nn.Sequential(self.inc, self.down1, self.down2, self.down3, self.down4)

    def forward(self, x):
        x1 = self.inc(x)
        x2 = self.down1(x1)
        x3 = self.down2(x2)
        x4 = self.down3(x3)
        x5 = self.down4(x4)
        x = self.up1(x5, x4)
        x = self.up2(x, x3)
        x = self.up3(x, x2)
        x = self.up4(x, x1)
        logits = self.outc(x)
        return {'logits': logits, 'mid': x5}

# run
def main():
    model = UNet(1).cuda()
    D1 = discriminator().cuda()
    D2 = discriminator().cuda()
    optM = optim.Adam(model.parameters(), lr=1e-3)
    optD1 = optim.Adam(list(D1.parameters()) + list(model.encoder.parameters()), lr=1e-3)
    optD2 = optim.Adam(list(D2.parameters()) + list(model.encoder.parameters()), lr=1e-3)

    for i in range(10):
        print("epoch: ", i)
        image = torch.randn(4, 1, 256, 256).cuda()
        out = model(image)

        f = out['mid']
        # detaching "f" makes the error disappear:
        # d1 = D1(f.detach() * 0.99)
        # d2 = D2(f.detach() * 1.01)
        d1 = D1(f * 0.99)
        d2 = D2(f * 1.01)
        pred = nn.Softmax(1)(out['logits'])
        loss = -torch.log(pred).mean()
        ld = torch.abs(d1 - d2).mean()
        
        optM.zero_grad()
        optD1.zero_grad()
        optD2.zero_grad()

        loss.backward(retain_graph=True)
        optM.step()

        ld.backward(retain_graph=True) # error here
        optD1.step()
        optD2.step()

    return

if __name__ == "__main__":
    main()

Using retain_graph=True often creates these errors and is usually not needed.
Could you explain why you are setting this argument to True?

When retain_graph=False, the error below appears:
“RuntimeError: Trying to backward through the graph a second time (or directly access saved tensors after they have already been freed). Saved intermediate values of the graph are freed when you call .backward() or autograd.grad(). Specify retain_graph=True if you need to backward through the graph a second time or if you need to access saved tensors after calling backward.”
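
A minimal toy example (my own, not the code above) of why retain_graph=True is needed here: loss and ld both backpropagate through the same encoder forward pass, so the first backward() frees saved tensors that the second backward() still needs.

import torch

w = torch.randn(3, requires_grad=True)
h = w.exp()              # shared intermediate, analogous to out['mid']
loss_a = h.sum()         # analogous to the segmentation loss
loss_b = (h * 3).sum()   # analogous to ld, which also depends on h

loss_a.backward()        # frees the saved tensors of the shared part of the graph
loss_b.backward()        # raises "Trying to backward through the graph a second time ..."
# loss_a.backward(retain_graph=True) keeps the saved tensors so the second call runs.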

I have solved the error to some extent:

If both backward() calls are executed before any step(), the error disappears.

I still wonder whether I can call step() before the second backward() and have the code run.

        loss.backward(retain_graph=True)
        ld.backward(retain_graph=True)
        
        optM.step()
        optD1.step()
        optD2.step()

You won’t be able to update the parameters via optM.step() first, since it would also update the shared parameters used to calculate mid. That would invalidate the stored intermediate forward activations, which are needed to compute the gradients of all parameters from ld.backward().
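
A small self-contained sketch of the difference (toy linear layers standing in for the UNet encoder and discriminators, not the actual modules above):

import torch
import torch.nn as nn
import torch.optim as optim

shared = nn.Sequential(nn.Linear(4, 4), nn.Linear(4, 4))  # stands in for model.encoder
seg_head = nn.Linear(4, 2)                                 # stands in for the decoder/outc path
disc = nn.Linear(4, 1)                                     # stands in for the discriminators

optM = optim.Adam(list(shared.parameters()) + list(seg_head.parameters()), lr=1e-3)
optD = optim.Adam(list(disc.parameters()) + list(shared.parameters()), lr=1e-3)

x = torch.randn(8, 4)
mid = shared(x)                     # shared forward activations, like out['mid']
loss = seg_head(mid).mean()         # plays the role of the segmentation loss
ld = disc(mid).abs().mean()         # plays the role of the discriminator loss

optM.zero_grad()
optD.zero_grad()

# Works: accumulate all gradients first, then update the parameters.
loss.backward(retain_graph=True)    # retain_graph because ld shares this graph
ld.backward()
optM.step()
optD.step()

# Fails: calling optM.step() between the two backward() calls modifies the shared
# parameters in place, so the tensors saved during the forward pass no longer match
# their recorded versions and ld.backward() raises the in-place RuntimeError.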

Thanks for your reply.
Got it.