Couldn't find the source of the error "one of the variables needed for gradient computation has been modified by an inplace operation"

Hello, I’m trying to implement a semi-supervised CatGAN, and I have a problem with an in-place modification error.

Here’s the detailed error with torch.autograd.set_detect_anomaly(True) enabled:

/usr/local/lib/python3.8/dist-packages/torch/autograd/__init__.py:197: UserWarning: Error detected in ConvolutionBackward0. Traceback of forward call that caused the error:
  File "/usr/lib/python3.8/runpy.py", line 194, in _run_module_as_main
    return _run_code(code, main_globals, None,
  File "/usr/lib/python3.8/runpy.py", line 87, in _run_code
    exec(code, run_globals)
  File "/usr/local/lib/python3.8/dist-packages/ipykernel_launcher.py", line 16, in <module>
    app.launch_new_instance()
  File "/usr/local/lib/python3.8/dist-packages/traitlets/config/application.py", line 992, in launch_instance
    app.start()
  File "/usr/local/lib/python3.8/dist-packages/ipykernel/kernelapp.py", line 612, in start
    self.io_loop.start()
  File "/usr/local/lib/python3.8/dist-packages/tornado/platform/asyncio.py", line 149, in start
    self.asyncio_loop.run_forever()
  File "/usr/lib/python3.8/asyncio/base_events.py", line 570, in run_forever
    self._run_once()
  File "/usr/lib/python3.8/asyncio/base_events.py", line 1859, in _run_once
    handle._run()
  File "/usr/lib/python3.8/asyncio/events.py", line 81, in _run
    self._context.run(self._callback, *self._args)
  File "/usr/local/lib/python3.8/dist-packages/tornado/ioloop.py", line 690, in <lambda>
    lambda f: self._run_callback(functools.partial(callback, future))
  File "/usr/local/lib/python3.8/dist-packages/tornado/ioloop.py", line 743, in _run_callback
    ret = callback()
  File "/usr/local/lib/python3.8/dist-packages/tornado/gen.py", line 787, in inner
    self.run()
  File "/usr/local/lib/python3.8/dist-packages/tornado/gen.py", line 748, in run
    yielded = self.gen.send(value)
  File "/usr/local/lib/python3.8/dist-packages/ipykernel/kernelbase.py", line 381, in dispatch_queue
    yield self.process_one()
  File "/usr/local/lib/python3.8/dist-packages/tornado/gen.py", line 225, in wrapper
    runner = Runner(result, future, yielded)
  File "/usr/local/lib/python3.8/dist-packages/tornado/gen.py", line 714, in __init__
    self.run()
  File "/usr/local/lib/python3.8/dist-packages/tornado/gen.py", line 748, in run
    yielded = self.gen.send(value)
  File "/usr/local/lib/python3.8/dist-packages/ipykernel/kernelbase.py", line 365, in process_one
    yield gen.maybe_future(dispatch(*args))
  File "/usr/local/lib/python3.8/dist-packages/tornado/gen.py", line 209, in wrapper
    yielded = next(result)
  File "/usr/local/lib/python3.8/dist-packages/ipykernel/kernelbase.py", line 268, in dispatch_shell
    yield gen.maybe_future(handler(stream, idents, msg))
  File "/usr/local/lib/python3.8/dist-packages/tornado/gen.py", line 209, in wrapper
    yielded = next(result)
  File "/usr/local/lib/python3.8/dist-packages/ipykernel/kernelbase.py", line 543, in execute_request
    self.do_execute(
  File "/usr/local/lib/python3.8/dist-packages/tornado/gen.py", line 209, in wrapper
    yielded = next(result)
  File "/usr/local/lib/python3.8/dist-packages/ipykernel/ipkernel.py", line 306, in do_execute
    res = shell.run_cell(code, store_history=store_history, silent=silent)
  File "/usr/local/lib/python3.8/dist-packages/ipykernel/zmqshell.py", line 536, in run_cell
    return super(ZMQInteractiveShell, self).run_cell(*args, **kwargs)
  File "/usr/local/lib/python3.8/dist-packages/IPython/core/interactiveshell.py", line 2854, in run_cell
    result = self._run_cell(
  File "/usr/local/lib/python3.8/dist-packages/IPython/core/interactiveshell.py", line 2881, in _run_cell
    return runner(coro)
  File "/usr/local/lib/python3.8/dist-packages/IPython/core/async_helpers.py", line 68, in _pseudo_sync_runner
    coro.send(None)
  File "/usr/local/lib/python3.8/dist-packages/IPython/core/interactiveshell.py", line 3057, in run_cell_async
    has_raised = await self.run_ast_nodes(code_ast.body, cell_name,
  File "/usr/local/lib/python3.8/dist-packages/IPython/core/interactiveshell.py", line 3249, in run_ast_nodes
    if (await self.run_code(code, result,  async_=asy)):
  File "/usr/local/lib/python3.8/dist-packages/IPython/core/interactiveshell.py", line 3326, in run_code
    exec(code_obj, self.user_global_ns, self.user_ns)
  File "<ipython-input-12-5de9c8e2df0f>", line 32, in <module>
    batches_losses_tmp_G, batches_losses_tmp_D, conditional_entropies_real_tmp, marginal_entropies_real_tmp, cross_entropies_tmp, conditional_entropies_fake_tmp, marginal_entropies_fake_tmp=train_loop_fun1(train_data_loader, discriminator, generator, optimizer_G, optimizer_D, latent_size, TRAIN_BATCH_SIZE, device, λ)
  File "<ipython-input-5-dbf1d4be348c>", line 60, in train_loop_fun1
    y_fake = discriminator(fake_images)
  File "/usr/local/lib/python3.8/dist-packages/torch/nn/modules/module.py", line 1190, in _call_impl
    return forward_call(*input, **kwargs)
  File "<ipython-input-8-8d93b18c1804>", line 28, in forward
    output = self.main(input)
  File "/usr/local/lib/python3.8/dist-packages/torch/nn/modules/module.py", line 1190, in _call_impl
    return forward_call(*input, **kwargs)
  File "/usr/local/lib/python3.8/dist-packages/torch/nn/modules/container.py", line 204, in forward
    input = module(input)
  File "/usr/local/lib/python3.8/dist-packages/torch/nn/modules/module.py", line 1190, in _call_impl
    return forward_call(*input, **kwargs)
  File "/usr/local/lib/python3.8/dist-packages/torch/nn/modules/conv.py", line 463, in forward
    return self._conv_forward(input, self.weight, self.bias)
  File "/usr/local/lib/python3.8/dist-packages/torch/nn/modules/conv.py", line 459, in _conv_forward
    return F.conv2d(input, weight, bias, self.stride,
  File "/usr/local/lib/python3.8/dist-packages/torch/fx/traceback.py", line 57, in format_stack
    return traceback.format_stack()
 (Triggered internally at ../torch/csrc/autograd/python_anomaly_mode.cpp:114.)
  Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass

My Discriminator:

class Discriminator(nn.Module):
    def __init__(self, dim=64, nb_channel=1):
        super(Discriminator, self).__init__()
        main = nn.Sequential(
            nn.Conv2d(nb_channel, dim, 4, 2, 3, bias=False),
            nn.BatchNorm2d(dim),
            nn.LeakyReLU(0.1, inplace=True),
            nn.Dropout(0.5),#64x16x16
            nn.Conv2d(dim, 2 * dim, 4, 2, 1, bias=False),
            nn.BatchNorm2d(2*dim),
            nn.LeakyReLU(0.1, inplace=True),
            nn.Dropout(0.5),#128x8x8
            nn.Conv2d(2 * dim, 4 * dim, 4, 2, 1, bias=False),
            nn.BatchNorm2d(4*dim),
            nn.LeakyReLU(0.1, inplace=True),
            nn.Dropout(0.5),#256x4x4
            nn.Conv2d(4*dim, 4*dim, 4),
            nn.BatchNorm2d(4*dim),
            nn.LeakyReLU(0.1, inplace=True),
            nn.Dropout(0.5),#256x1x1
            nn.Conv2d(4*dim, 10, 1)
        )

        self.main = main
        self.softmax = nn.Softmax(dim=1)

    def forward(self, input):
        output = self.main(input)
        output = output.view(-1, 10)
        output = self.softmax(output)
        return output

This is the part of my code where I suspect the error is:

for batch_idx, (data, targets, use_label, data_idx) in enumerate(data_loader):
...
        # sample latent noise (torch.randn draws from a standard normal distribution)
        z = torch.randn(batch_size, latent_size, 1, 1).to(device=device)
        fake_images = generator(z)
        
        y_fake = discriminator(fake_images)

        conditional_entropy_fake = conditional_entropy(y_fake, batch_size)#maximize uncertainty

        loss_D = conditional_entropy_real - marginal_entropy_real - conditional_entropy_fake + (λ*cross_entropy)
        loss_D.backward(retain_graph=True)
        optimizer_D.step()
        
        
#         * Updating the Generator *


        # freeze the discriminator and update the generator
        for p in discriminator.parameters():
            p.requires_grad = False  
        for p in generator.parameters():
            p.requires_grad = True  
        generator.zero_grad()
        marginal_entropy_fake = marginal_entropy(y_fake) #maximize uncertainty

        loss_G = conditional_entropy_fake - marginal_entropy_fake


        loss_G.backward(retain_graph=True)
         ...

Users often add retain_graph=True to their backward calls to work around other errors, but doing so can in turn introduce these in-place modification errors.
Could you explain why retain_graph=True is used in your code?
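
For readers following along, here is a minimal sketch of how retain_graph=True can lead to this error. It is not the poster's code: the small `net`, the two losses `loss_a` and `loss_b`, and the SGD optimizer are all made up for illustration. The pattern is the same as in the training loop above, though: one forward pass, a first backward with retain_graph=True, an optimizer step that updates the weights in place, and then a second backward through the same (now stale) graph.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Hypothetical toy model and optimizer, only for illustration.
net = nn.Sequential(nn.Linear(4, 4), nn.Tanh(), nn.Linear(4, 1))
opt = torch.optim.SGD(net.parameters(), lr=0.1)

x = torch.randn(8, 4)
out = net(x)                 # single forward pass shared by both losses
loss_a = out.mean()
loss_b = (out ** 2).mean()

loss_a.backward(retain_graph=True)  # graph is kept alive for a second backward
opt.step()                          # weights are updated in place here

try:
    # The retained graph still refers to the weights as they were during the
    # forward pass, so autograd detects the in-place update and refuses.
    loss_b.backward()
    raised = False
except RuntimeError as err:
    raised = "inplace operation" in str(err)

print("in-place error raised:", raised)
```

So retain_graph=True only hides the "backward through the graph a second time" error; the saved weights no longer match once optimizer.step() has run in between.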

Hello @ptrblck, thank you for your answer. I was only testing it to try to solve my problem, but it didn’t change anything; I keep getting the same error.

Here’s a notebook with my detailed code Google Colab

Thanks in advance for any help!

Hello @ptrblck, here’s the error I get when running it without retain_graph=True:

Trying to backward through the graph a second time (or directly access saved tensors after they have already been freed). Saved intermediate values of the graph are freed when you call .backward() or autograd.grad(). Specify retain_graph=True if you need to backward through the graph a second time or if you need to access saved tensors after calling backward.

The error is raised because you are:

  • reusing conditional_entropy_fake, which was already used in loss_D and its .backward() call
  • reusing y_fake to compute marginal_entropy_fake, although y_fake was already used to compute conditional_entropy_fake.

Both issues raise the runtime error because the computation graph behind these tensors was already freed by the earlier loss_D.backward() call.
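
One possible way to restructure the loop, given as a hedged sketch rather than a drop-in fix: detach the fake images for the discriminator step, then run a fresh forward pass through the updated discriminator for the generator step, so neither backward call needs retain_graph=True. The tiny linear networks, the Adam optimizers, and the bodies of `conditional_entropy` and `marginal_entropy` below are assumptions made so the example runs on its own (the thread does not show the real entropy functions, and the original `conditional_entropy` also takes `batch_size`). The real-data terms of loss_D are left out.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Tiny stand-ins for the real networks (assumptions, not the thread's models).
generator = nn.Sequential(nn.Linear(8, 16), nn.ReLU(), nn.Linear(16, 16))
discriminator = nn.Sequential(nn.Linear(16, 10), nn.Softmax(dim=1))
optimizer_G = torch.optim.Adam(generator.parameters(), lr=1e-3)
optimizer_D = torch.optim.Adam(discriminator.parameters(), lr=1e-3)

def conditional_entropy(y):
    # Mean per-sample entropy of the predicted class distribution (assumed form).
    return -(y * torch.log(y + 1e-8)).sum(dim=1).mean()

def marginal_entropy(y):
    # Entropy of the batch-averaged class distribution (assumed form).
    p = y.mean(dim=0)
    return -(p * torch.log(p + 1e-8)).sum()

z = torch.randn(32, 8)
fake_images = generator(z)

# Discriminator step: detach so this backward never touches the generator graph.
optimizer_D.zero_grad()
y_fake_D = discriminator(fake_images.detach())
loss_D = -conditional_entropy(y_fake_D)  # real-data terms omitted in this sketch
loss_D.backward()
optimizer_D.step()

# Generator step: new forward pass through the already-updated discriminator,
# which builds a fresh graph instead of reusing the freed one.
optimizer_G.zero_grad()
y_fake_G = discriminator(fake_images)
loss_G = conditional_entropy(y_fake_G) - marginal_entropy(y_fake_G)
loss_G.backward()
optimizer_G.step()
```

The generator backward also writes gradients into the discriminator's parameters; they are simply discarded by optimizer_D.zero_grad() at the start of the next discriminator step, so toggling requires_grad is optional in this layout.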