In-place operation error when calling backward() on the discriminator loss

I got the following error (full traceback below):

The problem occurs in the training loop when calling backward() on the discriminator loss, not on the generator loss.

# GAN training loop as originally posted. It fails inside disc_loss.backward()
# (marked below) with autograd's in-place-modification error. Anomaly mode
# attributes the failing backward node to the F.conv1d call inside model's
# forward pass.
for epoch in range(0, EPOCH):
  print("EPOCH number = %d" % epoch)
  for x in iter(dataloader):
    # Anomaly detection makes autograd report the forward call that created the
    # backward node which failed.
    with torch.autograd.set_detect_anomaly(True):

      # G_x carries model's autograd graph. Both training phases below reuse it
      # without detaching it.
      G_x = model(x)
      
      # Generator train
      disc_fake_wav = wave_disc(G_x)
      disc_real_wav = wave_disc(x)

      disc_fake_stft = stft_disc(G_x)
      disc_real_stft = stft_disc(x)

      gen_loss = L_G(disc_fake_wav, disc_real_wav, disc_fake_stft, disc_real_stft, x, G_x)
      # NOTE(review): this disc_loss is never used. It is recomputed below,
      # before the discriminator backward pass.
      disc_loss = L_D(disc_fake_wav, disc_real_wav, disc_fake_stft, disc_real_stft)

      generator_optim.zero_grad()
      # retain_graph=True keeps the tensors saved during model(x) alive, so a
      # later backward pass can traverse model's graph a second time.
      gen_loss.backward(retain_graph=True)
      # Optimizer steps update their parameters in place. Presumably these are
      # model's parameters; generator_optim's construction is not shown. Any
      # retained graph that saved those weights is now stale.
      generator_optim.step()

      # Discriminator train
      # BUG: G_x is not detached here. disc_loss therefore still depends on
      # model's graph, and disc_loss.backward() backpropagates into it. There it
      # reaches weights that generator_optim.step() modified in place, which
      # raises the in-place error.
      # Fix: pass G_x.detach() to wave_disc / stft_disc in this phase. Then
      # retain_graph=True is no longer needed on either backward call.
      disc_fake_wav = wave_disc(G_x)
      disc_real_wav = wave_disc(x)

      disc_fake_stft = stft_disc(G_x)
      disc_real_stft = stft_disc(x)

      disc_loss = L_D(disc_fake_wav, disc_real_wav, disc_fake_stft, disc_real_stft)

      discriminator_optim.zero_grad()
      disc_loss.backward(retain_graph=True) # Error occurred here (see BUG note above)
      discriminator_optim.step()

      # NOTE(review): %d converts the loss to an integer, so the fractional part
      # is truncated. Presumably "%f" % loss.item() was intended.
      print("Training generator loss = %d" % gen_loss)
      print("Training discriminator loss = %d" % disc_loss)  

Full traceback:

c:\Python\Python39\lib\site-packages\torch\autograd\__init__.py:154: UserWarning: Error detected in MkldnnConvolutionBackward0. Traceback of forward call that caused the error:
File "c:\Python\Python39\lib\runpy.py", line 197, in _run_module_as_main
return _run_code(code, main_globals, None,
File "c:\Python\Python39\lib\runpy.py", line 87, in _run_code
exec(code, run_globals)
File "c:\Python\Python39\lib\site-packages\ipykernel_launcher.py", line 16, in <module>
app.launch_new_instance()
File "c:\Python\Python39\lib\site-packages\traitlets\config\application.py", line 846, in launch_instance
app.start()
File "c:\Python\Python39\lib\site-packages\ipykernel\kernelapp.py", line 677, in start
self.io_loop.start()
File "c:\Python\Python39\lib\site-packages\tornado\platform\asyncio.py", line 199, in start
self.asyncio_loop.run_forever()
File "c:\Python\Python39\lib\asyncio\base_events.py", line 596, in run_forever
self._run_once()
File "c:\Python\Python39\lib\asyncio\base_events.py", line 1890, in _run_once
handle._run()
File "c:\Python\Python39\lib\asyncio\events.py", line 80, in _run
self._context.run(self._callback, *self._args)
File "c:\Python\Python39\lib\site-packages\ipykernel\kernelbase.py", line 457, in dispatch_queue
await self.process_one()
File "c:\Python\Python39\lib\site-packages\ipykernel\kernelbase.py", line 446, in process_one
await dispatch(*args)
File "c:\Python\Python39\lib\site-packages\ipykernel\kernelbase.py", line 353, in dispatch_shell
await result
File "c:\Python\Python39\lib\site-packages\ipykernel\kernelbase.py", line 648, in execute_request
reply_content = await reply_content
File "c:\Python\Python39\lib\site-packages\ipykernel\ipkernel.py", line 353, in do_execute
res = shell.run_cell(code, store_history=store_history, silent=silent)
File "c:\Python\Python39\lib\site-packages\ipykernel\zmqshell.py", line 533, in run_cell
return super(ZMQInteractiveShell, self).run_cell(*args, **kwargs)
File "c:\Python\Python39\lib\site-packages\IPython\core\interactiveshell.py", line 2914, in run_cell
result = self._run_cell(
File "c:\Python\Python39\lib\site-packages\IPython\core\interactiveshell.py", line 2960, in _run_cell
return runner(coro)
File "c:\Python\Python39\lib\site-packages\IPython\core\async_helpers.py", line 78, in _pseudo_sync_runner
coro.send(None)
File "c:\Python\Python39\lib\site-packages\IPython\core\interactiveshell.py", line 3185, in run_cell_async
has_raised = await self.run_ast_nodes(code_ast.body, cell_name,
File "c:\Python\Python39\lib\site-packages\IPython\core\interactiveshell.py", line 3377, in run_ast_nodes
if (await self.run_code(code, result, async_=asy)):
File "c:\Python\Python39\lib\site-packages\IPython\core\interactiveshell.py", line 3457, in run_code
exec(code_obj, self.user_global_ns, self.user_ns)
File "C:\Users\ACHYUT~1\AppData\Local\Temp/ipykernel_11308/2867750421.py", line 6, in <module>
G_x = model(x)
File "c:\Python\Python39\lib\site-packages\torch\nn\modules\module.py", line 1102, in _call_impl
return forward_call(*input, **kwargs)
File "e:\Programming\Final Major\model.py", line 43, in forward
return self.decode(frames)[:, :, :x.shape[-1]].clone()
File "e:\Programming\Final Major\model.py", line 32, in decode
return self._decode_frame(encoded_frames)
File "e:\Programming\Final Major\model.py", line 38, in _decode_frame
out = self.decoder(emb)
File "c:\Python\Python39\lib\site-packages\torch\nn\modules\module.py", line 1102, in _call_impl
return forward_call(*input, **kwargs)
File "e:\Programming\Final Major\decoder.py", line 41, in forward
return self.layers(x)
File "c:\Python\Python39\lib\site-packages\torch\nn\modules\module.py", line 1102, in _call_impl
return forward_call(*input, **kwargs)
File "c:\Python\Python39\lib\site-packages\torch\nn\modules\container.py", line 141, in forward
input = module(input)
File "c:\Python\Python39\lib\site-packages\torch\nn\modules\module.py", line 1102, in _call_impl
return forward_call(*input, **kwargs)
File "e:\Programming\Final Major\conv1d.py", line 46, in forward
return self.conv1d(x)
File "c:\Python\Python39\lib\site-packages\torch\nn\modules\module.py", line 1102, in _call_impl
return forward_call(*input, **kwargs)
File "c:\Python\Python39\lib\site-packages\torch\nn\modules\conv.py", line 301, in forward
return self._conv_forward(input, self.weight, self.bias)
File "c:\Python\Python39\lib\site-packages\torch\nn\modules\conv.py", line 297, in _conv_forward
return F.conv1d(input, weight, bias, self.stride,
(Triggered internally at ..\torch\csrc\autograd\python_anomaly_mode.cpp:104.)
Variable._execution_engine.run_backward(

Could you please try removing retain_graph=True from the line that gives the error?

I’m not sure why it’s needed there; could you explain?

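I solved the problem by removing `retain_graph=True` and detaching the fake (generated) tensor when calculating the discriminator loss. Detaching `G_x` stops the discriminator's backward pass at the generator output. As a result, it never revisits the generator weights that `generator_optim.step()` had already modified in place.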
Code snippet:

      # Discriminator train
      # Take a grad-free view of the generated audio once and reuse it. With the
      # generator cut out of this graph, disc_loss.backward() stops at the fake
      # input and never touches the generator weights updated by the
      # generator's optimizer step. The discriminator call order is unchanged.
      fake_audio = G_x.detach()
      disc_fake_wav = wave_disc(fake_audio)
      disc_real_wav = wave_disc(x)

      disc_fake_stft = stft_disc(fake_audio)
      disc_real_stft = stft_disc(x)