UNet3+: RuntimeError: one of the variables needed for gradient computation has been modified by an inplace operation

When I try to train my UNet3+ implementation, I get the following error:
RuntimeError: one of the variables needed for gradient computation has been modified by an inplace operation: [torch.cuda.FloatTensor [1, 1024, 64, 64]], which is output 0 of MulBackward0, is at version 1; expected version 0 instead. Hint: enable anomaly detection to find the operation that failed to compute its gradient, with torch.autograd.set_detect_anomaly(True).

My code is below:

import torch
import torch.nn as nn


def make_divisible(value, divisor=8, min_value=None, min_ratio=0.9):
    # Round `value` to the nearest multiple of `divisor`, but never
    # below `min_ratio` of the original value.
    if min_value is None:
        min_value = divisor
    new_value = max(min_value, int(value + divisor / 2) // divisor * divisor)
    if new_value < min_ratio * value:
        new_value += divisor
    return new_value


class Encoder(nn.Module):
    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.relu = nn.ReLU(inplace=True)
        self.conv_1 = nn.Conv2d(in_channels, out_channels,
                                kernel_size=3, stride=1, padding=1, bias=False)
        self.norm_1 = nn.BatchNorm2d(out_channels)
        self.conv_2 = nn.Conv2d(out_channels, out_channels,
                                kernel_size=3, stride=1, padding=1, bias=False)
        self.norm_2 = nn.BatchNorm2d(out_channels)

    def forward(self, x):
        x = self.relu(self.norm_1(self.conv_1(x)))
        x = self.relu(self.norm_2(self.conv_2(x)))
        return x


class Decoder(nn.Module):
    def __init__(self, in_channels, out_channels, down=None, up=None):
        super().__init__()
        layers = []
        if down:
            layers.append(nn.MaxPool2d(kernel_size=down, stride=down))
        elif up:
            layers.append(nn.Upsample(scale_factor=up, mode='bilinear'))
        layers.extend([
            nn.Conv2d(in_channels, out_channels,
                      kernel_size=3, stride=1, padding=1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True)
        ])
        self.decoder = nn.Sequential(*layers)

    def forward(self, x):
        x = self.decoder(x)
        return x

class UNet3P(nn.Module):
    def __init__(self, classes=10, in_channels=3, channel_ratio=1.0):
        super().__init__()
        self.down_sampling = nn.MaxPool2d(kernel_size=2, stride=2)
        self.encoder_1 = Encoder(in_channels, make_divisible(64 * channel_ratio))
        self.encoder_2 = Encoder(make_divisible(64 * channel_ratio), make_divisible(128 * channel_ratio))
        self.encoder_3 = Encoder(make_divisible(128 * channel_ratio), make_divisible(256 * channel_ratio))
        self.encoder_4 = Encoder(make_divisible(256 * channel_ratio), make_divisible(512 * channel_ratio))
        self.encoder_5 = Encoder(make_divisible(512 * channel_ratio), make_divisible(1024 * channel_ratio))

        self.decoder_4_1 = Decoder(make_divisible(64 * channel_ratio), make_divisible(64 * channel_ratio), down=8)
        self.decoder_4_2 = Decoder(make_divisible(128 * channel_ratio), make_divisible(64 * channel_ratio), down=4)
        self.decoder_4_3 = Decoder(make_divisible(256 * channel_ratio), make_divisible(64 * channel_ratio), down=2)
        self.decoder_4_4 = Decoder(make_divisible(512 * channel_ratio), make_divisible(64 * channel_ratio))
        self.decoder_4_5 = Decoder(make_divisible(1024 * channel_ratio), make_divisible(64 * channel_ratio), up=2)
        self.decoder_4_fusion = Decoder(make_divisible(320 * channel_ratio), make_divisible(320 * channel_ratio))

        self.decoder_3_1 = Decoder(make_divisible(64 * channel_ratio), make_divisible(64 * channel_ratio), down=4)
        self.decoder_3_2 = Decoder(make_divisible(128 * channel_ratio), make_divisible(64 * channel_ratio), down=2)
        self.decoder_3_3 = Decoder(make_divisible(256 * channel_ratio), make_divisible(64 * channel_ratio))
        self.decoder_3_4 = Decoder(make_divisible(320 * channel_ratio), make_divisible(64 * channel_ratio), up=2)
        self.decoder_3_5 = Decoder(make_divisible(1024 * channel_ratio), make_divisible(64 * channel_ratio), up=4)
        self.decoder_3_fusion = Decoder(make_divisible(320 * channel_ratio), make_divisible(320 * channel_ratio))

        self.decoder_2_1 = Decoder(make_divisible(64 * channel_ratio), make_divisible(64 * channel_ratio), down=2)
        self.decoder_2_2 = Decoder(make_divisible(128 * channel_ratio), make_divisible(64 * channel_ratio))
        self.decoder_2_3 = Decoder(make_divisible(320 * channel_ratio), make_divisible(64 * channel_ratio), up=2)
        self.decoder_2_4 = Decoder(make_divisible(320 * channel_ratio), make_divisible(64 * channel_ratio), up=4)
        self.decoder_2_5 = Decoder(make_divisible(1024 * channel_ratio), make_divisible(64 * channel_ratio), up=8)
        self.decoder_2_fusion = Decoder(make_divisible(320 * channel_ratio), make_divisible(320 * channel_ratio))

        self.decoder_1_1 = Decoder(make_divisible(64 * channel_ratio), make_divisible(64 * channel_ratio))
        self.decoder_1_2 = Decoder(make_divisible(320 * channel_ratio), make_divisible(64 * channel_ratio), up=2)
        self.decoder_1_3 = Decoder(make_divisible(320 * channel_ratio), make_divisible(64 * channel_ratio), up=4)
        self.decoder_1_4 = Decoder(make_divisible(320 * channel_ratio), make_divisible(64 * channel_ratio), up=8)
        self.decoder_1_5 = Decoder(make_divisible(1024 * channel_ratio), make_divisible(64 * channel_ratio), up=16)
        self.decoder_1_fusion = Decoder(make_divisible(320 * channel_ratio), make_divisible(320 * channel_ratio))

        self.final_layer_5 = nn.Sequential(
            nn.Conv2d(make_divisible(1024 * channel_ratio), classes,
                      kernel_size=3, stride=1, padding=1),
            nn.Upsample(scale_factor=16, mode='bilinear')
        )
        self.final_layer_4 = nn.Sequential(
            nn.Conv2d(make_divisible(320 * channel_ratio), classes,
                      kernel_size=3, stride=1, padding=1),
            nn.Upsample(scale_factor=8, mode='bilinear')
        )
        self.final_layer_3 = nn.Sequential(
            nn.Conv2d(make_divisible(320 * channel_ratio), classes,
                      kernel_size=3, stride=1, padding=1),
            nn.Upsample(scale_factor=4, mode='bilinear')
        )
        self.final_layer_2 = nn.Sequential(
            nn.Conv2d(make_divisible(320 * channel_ratio), classes,
                      kernel_size=3, stride=1, padding=1),
            nn.Upsample(scale_factor=2, mode='bilinear')
        )
        self.final_layer_1 = nn.Sequential(
            nn.Conv2d(make_divisible(320 * channel_ratio), classes,
                      kernel_size=3, stride=1, padding=1)
        )

        self.cls = nn.Sequential(
            nn.Dropout(p=0.5, inplace=True),
            nn.Conv2d(make_divisible(1024 * channel_ratio), 2,
                      kernel_size=1, stride=1, padding=0),
            nn.AdaptiveMaxPool2d(1),
            nn.Sigmoid()
        )

        self.out_clock = nn.Conv2d(15, classes, kernel_size=1, stride=1)  # 15 = 5 supervision heads * classes, hard-coded for classes=3

    @staticmethod
    def _dot_product(seg, cls):
        b, c, h, w = seg.size()
        seg = seg.view(b, c, h * w)
        x = torch.einsum("ijk,ij->ijk", [seg, cls])
        x = x.view(b, c, h, w)
        return x

    def forward(self, x):
        e_1 = self.encoder_1(x)
        e_2 = self.encoder_2(self.down_sampling(e_1))
        e_3 = self.encoder_3(self.down_sampling(e_2))
        e_4 = self.encoder_4(self.down_sampling(e_3))
        e_5 = self.encoder_5(self.down_sampling(e_4))
        d_5 = e_5

        d_4_1 = self.decoder_4_1(e_1)
        d_4_2 = self.decoder_4_2(e_2)
        d_4_3 = self.decoder_4_3(e_3)
        d_4_4 = self.decoder_4_4(e_4)
        d_4_5 = self.decoder_4_5(d_5)
        d_4 = self.decoder_4_fusion(torch.cat((d_4_1, d_4_2, d_4_3, d_4_4, d_4_5), dim=1))

        d_3_1 = self.decoder_3_1(e_1)
        d_3_2 = self.decoder_3_2(e_2)
        d_3_3 = self.decoder_3_3(e_3)
        d_3_4 = self.decoder_3_4(d_4)
        d_3_5 = self.decoder_3_5(d_5)
        d_3 = self.decoder_3_fusion(torch.cat((d_3_1, d_3_2, d_3_3, d_3_4, d_3_5), dim=1))

        d_2_1 = self.decoder_2_1(e_1)
        d_2_2 = self.decoder_2_2(e_2)
        d_2_3 = self.decoder_2_3(d_3)
        d_2_4 = self.decoder_2_4(d_4)
        d_2_5 = self.decoder_2_5(d_5)
        d_2 = self.decoder_2_fusion(torch.cat((d_2_1, d_2_2, d_2_3, d_2_4, d_2_5), dim=1))

        d_1_1 = self.decoder_1_1(e_1)
        d_1_2 = self.decoder_1_2(d_2)
        d_1_3 = self.decoder_1_3(d_3)
        d_1_4 = self.decoder_1_4(d_4)
        d_1_5 = self.decoder_1_5(d_5)
        d_1 = self.decoder_1_fusion(torch.cat((d_1_1, d_1_2, d_1_3, d_1_4, d_1_5), dim=1))

        d_5_deep_supervision = self.final_layer_5(d_5)
        d_4_deep_supervision = self.final_layer_4(d_4)
        d_3_deep_supervision = self.final_layer_3(d_3)
        d_2_deep_supervision = self.final_layer_2(d_2)
        d_1_deep_supervision = self.final_layer_1(d_1)

        cls_branch = self.cls(d_5).squeeze(3).squeeze(2)
        cls_branch_max = cls_branch.argmax(dim=1)
        cls_branch_max = cls_branch_max.unsqueeze(1).float()

        x_1 = self._dot_product(d_1_deep_supervision, cls_branch_max)
        x_2 = self._dot_product(d_2_deep_supervision, cls_branch_max)
        x_3 = self._dot_product(d_3_deep_supervision, cls_branch_max)
        x_4 = self._dot_product(d_4_deep_supervision, cls_branch_max)
        x_5 = self._dot_product(d_5_deep_supervision, cls_branch_max)

        out = torch.cat([x_1, x_2, x_3, x_4, x_5], dim=1)
        out = self.out_clock(out)
        # return x_5, x_4, x_3, x_2, x_1
        return out


if __name__ == '__main__':
    import torch.optim as optim
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    in_data = torch.randn(1, 3, 1024, 1024).to(device)  # b, c, h, w
    out_data = torch.randn(1, 3, 1024, 1024).to(device)  # b, c, h, w

    with torch.autograd.set_detect_anomaly(True):
        model = UNet3P(classes=3, in_channels=3, channel_ratio=1).to(device)
    loss_fc = nn.MSELoss()
    optimizer = optim.Adam(model.parameters(), lr=0.1)

    pred_data = model(in_data)
    optimizer.zero_grad()
    with torch.autograd.set_detect_anomaly(True):
        loss = loss_fc(pred_data, out_data)
    loss.backward()
    print(out_data.size())

The issue is in this line of the code you posted:

cls_branch = self.cls(d_5).squeeze(3).squeeze(2)

Looking at the self.cls block, the error is caused by nn.Dropout(p=0.5, inplace=True). To fix it, you have two options:
1. Change inplace to False:

self.cls = nn.Sequential(
    nn.Dropout(p=0.5, inplace=False),  # Set inplace to False
    nn.Conv2d(make_divisible(1024 * channel_ratio), 2,
              kernel_size=1, stride=1, padding=0),
    nn.AdaptiveMaxPool2d(1),
    nn.Sigmoid()
)
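
With inplace=False, Dropout writes its output to a fresh tensor and leaves d_5 untouched, so the values the other branches saved for their backward pass stay valid. The cost is one extra allocation, which is negligible here.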

2. Clone the tensor before passing it to self.cls:

d_5_clone = d_5.clone()  # the in-place dropout now operates on this copy
cls_branch = self.cls(d_5_clone).squeeze(3).squeeze(2)
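
This keeps inplace=True but applies the dropout to a copy, so the original d_5 (which the decoder branches and final_layer_5 still need during backward) is preserved. Note that the clone costs a full copy of d_5, so option 1 is usually the simpler choice.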

The underlying problem: with inplace=True, the Dropout layer overwrites its input tensor instead of allocating a new one. Here that input is d_5, which is also consumed by decoder_4_5, decoder_3_5, decoder_2_5, decoder_1_5, and final_layer_5. The convolution layers in those branches save d_5 for their backward pass, so when the in-place dropout modifies d_5 and bumps its version counter, autograd detects the mismatch during backward() and raises the RuntimeError.
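
Here is a minimal sketch, independent of your model, that reproduces the same class of error; conv_a, conv_b, and drop are hypothetical stand-ins for encoder_5, final_layer_5, and the dropout inside self.cls:

import torch
import torch.nn as nn

conv_a = nn.Conv2d(8, 8, kernel_size=1)   # stands in for encoder_5
conv_b = nn.Conv2d(8, 8, kernel_size=1)   # stands in for final_layer_5
drop = nn.Dropout(p=0.5, inplace=True)    # stands in for the dropout in self.cls

x = torch.randn(1, 8, 4, 4)
feat = conv_a(x)    # plays the role of d_5
out = conv_b(feat)  # conv_b saves feat for its backward pass (weight gradient)
cls = drop(feat)    # in-place dropout overwrites feat, bumping its version counter
(out.sum() + cls.sum()).backward()  # RuntimeError: ... modified by an inplace operation

Switching to inplace=False (or cloning feat first) makes the backward pass succeed. One more note on your test script: the set_detect_anomaly context managers there only cover model construction and the loss computation, so anomaly mode is off again by the time the forward and backward passes run. To get a traceback pointing at the failing operation, call torch.autograd.set_detect_anomaly(True) once at the top of the script, or wrap model(in_data) and loss.backward() in the with block.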