RuntimeError: one of the variables needed for gradient computation has been modified by an inplace operation, but there is no in-place operation in my code

I know there are other posts about the same error, but I am still not able to find where the in-place operation is in my code.

The error is:

UserWarning: Error detected in ConvolutionBackward0. Traceback of forward call that caused the error:
  File "<frozen runpy>", line 198, in _run_module_as_main
  File "<frozen runpy>", line 88, in _run_code
  File "/packages/apps/jupyter/2023-10-09/lib/python3.11/site-packages/ipykernel_launcher.py", line 17, in <module>
    app.launch_new_instance()
  File "/packages/apps/jupyter/2023-10-09/lib/python3.11/site-packages/traitlets/config/application.py", line 1053, in launch_instance
    app.start()
  File "/packages/apps/jupyter/2023-10-09/lib/python3.11/site-packages/ipykernel/kernelapp.py", line 736, in start
    self.io_loop.start()
  File "/packages/apps/jupyter/2023-10-09/lib/python3.11/site-packages/tornado/platform/asyncio.py", line 195, in start
    self.asyncio_loop.run_forever()
  File "/packages/apps/jupyter/2023-10-09/lib/python3.11/asyncio/base_events.py", line 607, in run_forever
    self._run_once()
  File "/packages/apps/jupyter/2023-10-09/lib/python3.11/asyncio/base_events.py", line 1922, in _run_once
    handle._run()
  File "/packages/apps/jupyter/2023-10-09/lib/python3.11/asyncio/events.py", line 80, in _run
    self._context.run(self._callback, *self._args)
  File "/packages/apps/jupyter/2023-10-09/lib/python3.11/site-packages/ipykernel/kernelbase.py", line 516, in dispatch_queue
    await self.process_one()
  File "/packages/apps/jupyter/2023-10-09/lib/python3.11/site-packages/ipykernel/kernelbase.py", line 505, in process_one
    await dispatch(*args)
  File "/packages/apps/jupyter/2023-10-09/lib/python3.11/site-packages/ipykernel/kernelbase.py", line 412, in dispatch_shell
    await result
  File "/packages/apps/jupyter/2023-10-09/lib/python3.11/site-packages/ipykernel/kernelbase.py", line 740, in execute_request
    reply_content = await reply_content
  File "/packages/apps/jupyter/2023-10-09/lib/python3.11/site-packages/ipykernel/ipkernel.py", line 422, in do_execute
    res = shell.run_cell(
  File "/packages/apps/jupyter/2023-10-09/lib/python3.11/site-packages/ipykernel/zmqshell.py", line 546, in run_cell
    return super().run_cell(*args, **kwargs)
  File "/packages/apps/jupyter/2023-10-09/lib/python3.11/site-packages/IPython/core/interactiveshell.py", line 3024, in run_cell
    result = self._run_cell(
  File "/packages/apps/jupyter/2023-10-09/lib/python3.11/site-packages/IPython/core/interactiveshell.py", line 3079, in _run_cell
    result = runner(coro)
  File "/packages/apps/jupyter/2023-10-09/lib/python3.11/site-packages/IPython/core/async_helpers.py", line 129, in _pseudo_sync_runner
    coro.send(None)
  File "/packages/apps/jupyter/2023-10-09/lib/python3.11/site-packages/IPython/core/interactiveshell.py", line 3284, in run_cell_async
    has_raised = await self.run_ast_nodes(code_ast.body, cell_name,
  File "/packages/apps/jupyter/2023-10-09/lib/python3.11/site-packages/IPython/core/interactiveshell.py", line 3466, in run_ast_nodes
    if await self.run_code(code, result, async_=asy):
  File "/packages/apps/jupyter/2023-10-09/lib/python3.11/site-packages/IPython/core/interactiveshell.py", line 3526, in run_code
    exec(code_obj, self.user_global_ns, self.user_ns)
  File "/tmp/ipykernel_447161/2607253225.py", line 2, in <module>
    main()
  File "/tmp/ipykernel_447161/1836616004.py", line 84, in main
    encoder_loss, decoder_loss, discriminator_loss = train_fn(train_loader, disc, encoder, decoder, opt_enc, opt_dec, opt_disc, mse, bce, BETA, GAMMA)
  File "/tmp/ipykernel_447161/1836616004.py", line 13, in train_fn
    stego = encoder(combined)
  File "/home/praddi/.local/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/praddi/.local/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1747, in _call_impl
    return forward_call(*args, **kwargs)
  File "/tmp/ipykernel_447161/1439291926.py", line 65, in forward
    out = self.model(x)
  File "/home/praddi/.local/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/praddi/.local/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1747, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/praddi/.local/lib/python3.11/site-packages/torch/nn/modules/container.py", line 250, in forward
    input = module(input)
  File "/home/praddi/.local/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/praddi/.local/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1747, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/praddi/.local/lib/python3.11/site-packages/torch/nn/modules/conv.py", line 554, in forward
    return self._conv_forward(input, self.weight, self.bias)
  File "/home/praddi/.local/lib/python3.11/site-packages/torch/nn/modules/conv.py", line 549, in _conv_forward
    return F.conv2d(
 (Triggered internally at ../torch/csrc/autograd/python_anomaly_mode.cpp:110.)
  return Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
  0%|          | 0/375 [00:01<?, ?it/s]
---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
Cell In[7], line 2
      1 torch.autograd.set_detect_anomaly(True)
----> 2 main()

Cell In[6], line 84, in main()
     76 train_loader = DataLoader(
     77     train_dataset,
     78     batch_size=BATCH_SIZE,
     79     shuffle=True,
     80     num_workers=NUM_WORKERS,
     81 )
     83 for i in range(NUM_EPOCHS):
---> 84     encoder_loss, decoder_loss, discriminator_loss = train_fn(train_loader, disc, encoder, decoder, opt_enc, opt_dec, opt_disc, mse, bce, BETA, GAMMA)
     85     disc_scheduler.step()
     86     enc_scheduler.step()

Cell In[6], line 41, in train_fn(loader, disc, encoder, decoder, opt_enc, opt_dec, opt_disc, mse, bce, beta, gamma)
     39         disloss = discover + disstego
     40         disc.zero_grad()
---> 41         disloss.backward(retain_graph=False)
     42         opt_disc.step()
     43 return total_encoder_loss / len(loader), total_decoder_loss / len(loader), total_discriminator_loss / len(loader)

File ~/.local/lib/python3.11/site-packages/torch/_tensor.py:581, in Tensor.backward(self, gradient, retain_graph, create_graph, inputs)
    571 if has_torch_function_unary(self):
    572     return handle_torch_function(
    573         Tensor.backward,
    574         (self,),
   (...)
    579         inputs=inputs,
    580     )
--> 581 torch.autograd.backward(
    582     self, gradient, retain_graph, create_graph, inputs=inputs
    583 )

File ~/.local/lib/python3.11/site-packages/torch/autograd/__init__.py:347, in backward(tensors, grad_tensors, retain_graph, create_graph, grad_variables, inputs)
    342     retain_graph = create_graph
    344 # The reason we repeat the same comment below is that
    345 # some Python versions print out the first line of a multi-line function
    346 # calls in the traceback and some print out the last line
--> 347 _engine_run_backward(
    348     tensors,
    349     grad_tensors_,
    350     retain_graph,
    351     create_graph,
    352     inputs,
    353     allow_unreachable=True,
    354     accumulate_grad=True,
    355 )

File ~/.local/lib/python3.11/site-packages/torch/autograd/graph.py:825, in _engine_run_backward(t_outputs, *args, **kwargs)
    823     unregister_hooks = _register_logging_hooks_on_whole_graph(t_outputs)
    824 try:
--> 825     return Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
    826         t_outputs, *args, **kwargs
    827     )  # Calls into the C++ engine to run the backward pass
    828 finally:
    829     if attach_logging_hooks:

RuntimeError: one of the variables needed for gradient computation has been modified by an inplace operation: [torch.cuda.FloatTensor [3, 64, 7, 7]] is at version 2; expected version 1 instead. Hint: the backtrace further above shows the operation that failed to compute its gradient. The variable in question was changed in there or anywhere later. Good luck!

The models are:

class ResNetBlock(nn.Module):
    def __init__(self, in_channels, out_channels, stride=1, dropout=0.5):
        super(ResNetBlock, self).__init__()
        self.block = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=stride, padding=1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
            nn.Dropout(p=dropout),
            nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1, bias=False),
            nn.BatchNorm2d(out_channels))

        # Shortcut connection (if necessary)
        self.shortcut = nn.Sequential()
        if stride != 1 or in_channels != out_channels:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(out_channels))

    def forward(self, x):
        # Main path
        out = self.block(x)
        # Shortcut path
        shortcut = self.shortcut(x)
        # Add the main path and the shortcut path
        out += shortcut
        return out


class Encoder(nn.Module):
    def __init__(self):
        super(Encoder, self).__init__()
        self.model = nn.Sequential(
            #Downsampling
            nn.Conv2d(6, 64, 7, padding=3),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),
            nn.Conv2d(64, 128, 3, padding=1, stride=2),
            nn.BatchNorm2d(128),
            nn.ReLU(inplace=True),
            nn.Conv2d(128, 256, 3, padding=1, stride=2),
            nn.BatchNorm2d(256),
            nn.ReLU(inplace=True),
            #9 resnet blocks
            ResNetBlock(in_channels=256, out_channels=256),
            ResNetBlock(in_channels=256, out_channels=256),
            ResNetBlock(in_channels=256, out_channels=256),
            ResNetBlock(in_channels=256, out_channels=256),
            ResNetBlock(in_channels=256, out_channels=256),
            ResNetBlock(in_channels=256, out_channels=256),
            ResNetBlock(in_channels=256, out_channels=256),
            ResNetBlock(in_channels=256, out_channels=256),
            ResNetBlock(in_channels=256, out_channels=256),
            #Upsampling
            nn.ConvTranspose2d(256, 128, 4, padding=1, stride=2),
            nn.BatchNorm2d(128),
            nn.ReLU(inplace=True),
            nn.ConvTranspose2d(128, 64, 4, padding=1, stride=2),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),
            nn.Conv2d(64, 3, 7, padding=3),
            nn.Tanh())

    def forward(self, x):
        #print(x.shape)
        out = self.model(x)
        return out

class Discriminator(nn.Module):
    def __init__(self):
        super(Discriminator, self).__init__()
        self.model = nn.Sequential(
            nn.Conv2d(3, 64, 3, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),
            nn.Conv2d(64, 128, 3, padding=1),
            nn.BatchNorm2d(128),
            nn.ReLU(inplace=True),
            nn.Conv2d(128, 256, 3, padding=1),
            nn.BatchNorm2d(256),
            nn.ReLU(inplace=True),
            nn.Conv2d(256, 128, 3, padding=1),
            nn.BatchNorm2d(128),
            nn.ReLU(inplace=True),
            nn.Conv2d(128, 64, 3, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),
            nn.Conv2d(64, 3, 3, padding=1),
            nn.BatchNorm2d(3),
            nn.Sigmoid())

    def forward(self, x):
        out = self.model(x)
        return out

class Decoder(nn.Module):
    def __init__(self):
        super(Decoder, self).__init__()
        self.model = nn.Sequential(
            nn.Conv2d(3, 64, 3, padding=1),
            nn.Conv2d(64, 128, 3, padding=1),
            nn.Conv2d(128, 256, 3, padding=1),
            nn.Conv2d(256, 128, 3, padding=1),
            nn.Conv2d(128, 64, 3, padding=1),
            nn.Conv2d(64, 3, 3, padding=1),
            nn.Sigmoid())

    def forward(self, x):
        out = self.model(x)
        return out

The training function along with the main function:

def train_fn(loader, disc, encoder, decoder, opt_enc, opt_dec, opt_disc, mse, bce, beta, gamma):
    loop = tqdm(loader, leave=True)
    total_encoder_loss = 0.0
    total_decoder_loss = 0.0
    total_discriminator_loss = 0.0
    for idx, (cover, secret) in enumerate(loop):
        #print(idx)
        #print(idx,secret.shape)
        secret, cover = secret.to(DEVICE), cover.to(DEVICE)
        #display_image(secret)
        combined = torch.cat((secret,cover),1)
        #print(combined.shape)
        stego = encoder(combined)
        retrieve = decoder(stego)
        disc_result = disc(stego)
        #print(stego.shape)
        #print(retrieve.shape)

        encoder_mse = mse(stego, cover)
        decoder_mse = mse(retrieve, secret)
        gen_disc_loss = bce(disc_result, torch.ones(stego.size()).to(DEVICE))

        total_encoder_loss += encoder_mse.item()
        total_decoder_loss += decoder_mse.item()
        total_discriminator_loss += gen_disc_loss.item()

        loss = encoder_mse + beta*decoder_mse + gamma*gen_disc_loss
        opt_enc.zero_grad()
        opt_dec.zero_grad()
        loss.backward(retain_graph=True)
        opt_enc.step()
        opt_dec.step()

        if idx % 2 == 0:
            disc_cover = disc(cover).to(DEVICE)
            disc_stego = disc(stego).to(DEVICE)
            discover = bce(disc_cover, torch.ones(disc_cover.size()).to(DEVICE))
            disstego = bce(disc_stego, torch.zeros(disc_stego.size()).to(DEVICE))
            disloss = discover + disstego
            disc.zero_grad()
            disloss.backward(retain_graph=False)
            opt_disc.step()
    return total_encoder_loss / len(loader), total_decoder_loss / len(loader), total_discriminator_loss / len(loader)

def main():
    disc = Discriminator().to(DEVICE)
    encoder = Encoder().to(DEVICE)
    decoder = Decoder().to(DEVICE)
    opt_disc = optim.SGD(disc.parameters(), lr=LR_DISCRIMINATOR)
    opt_enc = optim.Adam(encoder.parameters(), lr=LR_ENCODER)
    opt_dec = optim.Adam(decoder.parameters(), lr=LR_DECODER)
    bce = nn.BCELoss().to(DEVICE)
    mse = nn.MSELoss().to(DEVICE)
    disc_scheduler = torch.optim.lr_scheduler.StepLR(opt_disc, step_size=4, gamma=0.9)
    enc_scheduler = torch.optim.lr_scheduler.StepLR(opt_enc, step_size=4, gamma=0.9)
    dec_scheduler = torch.optim.lr_scheduler.StepLR(opt_dec, step_size=4, gamma=0.9)

    #Load dataset
    train_dataset = MapDataset(root_dir_cover=TRAIN_DIR_COVER, root_dir_hidden=TRAIN_DIR_HIDDEN)
    train_loader = DataLoader(
        train_dataset,
        batch_size=BATCH_SIZE,
        shuffle=True,
        num_workers=NUM_WORKERS,
    )

    for i in range(NUM_EPOCHS):
        encoder_loss, decoder_loss, discriminator_loss = train_fn(train_loader, disc, encoder, decoder, opt_enc, opt_dec, opt_disc, mse, bce, BETA, GAMMA)
        disc_scheduler.step()
        enc_scheduler.step()
        dec_scheduler.step()
        print(f'Epoch [{i}/{NUM_EPOCHS}] | Encoder Loss: {encoder_loss} | Decoder Loss: {decoder_loss} | Discriminator Loss: {discriminator_loss}')
        if i % 5 == 0:
            save_checkpoint(encoder, opt_enc, encoder_loss, filename=CHECKPOINT_ENC)
            save_checkpoint(decoder, opt_dec, decoder_loss, filename=CHECKPOINT_DEC)
            save_checkpoint(disc, opt_disc, discriminator_loss, filename=CHECKPOINT_DISC)

        if i % 10 == 0:
            with torch.no_grad():
                secret, cover = next(iter(train_loader))
                secret, cover = secret.to(DEVICE), cover.to(DEVICE)
                combined = torch.cat((cover, secret), 1)
                stego = encoder(combined).to(DEVICE)
                retrieve = decoder(stego).to(DEVICE)
                disc_result = disc(stego).to(DEVICE)
                display_image(cover, secret, stego, retrieve)

Please help

The in-place operation is not an explicit one in your forward code; it is the optimizer step. loss.backward(retain_graph=True) keeps the computation graph of the encoder/decoder forward pass alive, and the subsequent opt_enc.step() and opt_dec.step() calls then update those parameters in place. When disloss.backward() later walks through disc(stego) back into that retained graph, the encoder weights no longer correspond to the values that produced the stored forward activations (their version counters have advanced), so the gradients would be computed incorrectly and autograd raises the error instead. The tensor named in the message, [torch.cuda.FloatTensor [3, 64, 7, 7]], is the weight of the final nn.Conv2d(64, 3, 7, padding=3) in the Encoder, which is exactly what opt_enc.step() just modified.
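One common way to fix this (a sketch, not the only option) is to train the discriminator on stego.detach(), so that disloss.backward() only reaches the discriminator's own parameters and never re-enters the encoder graph; retain_graph=True then becomes unnecessary as well. Below is a minimal sketch of the reworked loop body, reusing the names from train_fn above and omitting the loss bookkeeping for brevity:

# inside train_fn(...), replacing the existing loop body
for idx, (cover, secret) in enumerate(loop):
    secret, cover = secret.to(DEVICE), cover.to(DEVICE)
    combined = torch.cat((secret, cover), 1)

    stego = encoder(combined)
    retrieve = decoder(stego)
    disc_result = disc(stego)

    encoder_mse = mse(stego, cover)
    decoder_mse = mse(retrieve, secret)
    gen_disc_loss = bce(disc_result, torch.ones_like(disc_result))

    loss = encoder_mse + beta * decoder_mse + gamma * gen_disc_loss
    opt_enc.zero_grad()
    opt_dec.zero_grad()
    loss.backward()      # no retain_graph=True needed anymore
    opt_enc.step()
    opt_dec.step()

    if idx % 2 == 0:
        # stego.detach() cuts the graph here, so this backward pass only
        # touches the discriminator's parameters and never sees the encoder
        # weights that were just updated in place above
        disc_cover = disc(cover)
        disc_stego = disc(stego.detach())
        discover = bce(disc_cover, torch.ones_like(disc_cover))
        disstego = bce(disc_stego, torch.zeros_like(disc_stego))
        disloss = discover + disstego
        opt_disc.zero_grad()
        disloss.backward()
        opt_disc.step()

The generator's adversarial term gen_disc_loss still backpropagates through the discriminator into the encoder as before; only the discriminator's own update uses the detached stego, which is the usual pattern in GAN-style training loops.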