One of the variables needed for gradient computation has been modified by an inplace operation (backward pass is performed twice)

I am running the code snippet

# Define torch NN module


class Net(Module):
    def __init__(self, qnn):
        super().__init__()
        self.conv1 = Conv2d(1, 2, kernel_size=5)
        self.conv2 = Conv2d(2, 16, kernel_size=5)
        self.dropout = Dropout2d()
        self.fc1 = Linear(256, 64)
        self.fc2 = Linear(64, 2)  # 2-dimensional input to QNN
        self.qnn = TorchConnector(qnn)  # Apply torch connector, weights chosen
        # uniformly at random from interval [-1,1].
        #self.fc3 = Linear(1, 1)  # 1-dimensional output from QNN

    def forward(self, x):
        x = F.relu(self.conv1(x))
        x = F.max_pool2d(x, 2)
        x = F.relu(self.conv2(x))
        x = F.max_pool2d(x, 2)
        x = self.dropout(x)
        x = x.view(x.shape[0], -1)
        x = F.relu(self.fc1(x))
        x = self.fc2(x).clone()
        x = self.qnn(x)  # apply QNN
        #x = self.fc3(x)
        return cat((x, 1 - x), -1)


model4 = Net(qnn4)

# Define model, optimizer, and loss function
optimizer = optim.Adam(list(model4.parameters())[:-1], lr=0.001)


minimizer_args = dict(method='COBYLA', options={'disp':True, 'maxiter':5}, jac = False)
optimizer_qnn = MinimizeWrapper(list(model4.parameters()), minimizer_args)

loss_fun = CrossEntropyLoss()
# Start training
epochs = 10  # Set number of epochs
loss_list = []  # Store loss history
model4.train()  # Set model to training mode

for epoch in range(epochs):
    total_loss = []
    print('qnn params', list(model4.qnn.parameters()))
    with torch.autograd.set_detect_anomaly(True): 
        for batch_idx, (data, target) in enumerate(train_loader):

            output = model4(data) 
            loss = loss_fun(output, target)
            
            
            
            
            optimizer.zero_grad()
            loss.backward(retain_graph=True)
            optimizer.step()
        
            total_loss.append(loss.item()) 
        def closure():
            optimizer_qnn.zero_grad()
            loss = loss_fun(output, target)  
            loss.backward()
            print(loss.item()) 
            return loss 
        optimizer_qnn.step(closure)   
    print('qnn params after', list(model4.qnn.parameters())) 
    loss_list.append(sum(total_loss) / len(total_loss))
    print("Training [{:.0f}%]\tLoss: {:.4f}".format(100.0 * (epoch + 1) / epochs, loss_list[-1]))

and I ran the anomaly detection and am getting the in-place operation error at x = self.fc2(x).clone() this line. I added clone() to see if the error is rectified, but nothing changed. Can somebody help me?

I am also posting the traceback below:


C:\Users\Srushti\AppData\Local\Packages\PythonSoftwareFoundation.Python.3.8_qbz5n2kfra8p0\LocalCache\local-packages\Python38\site-packages\torch\autograd\__init__.py:200: UserWarning: Error detected in AddmmBackward0. Traceback of forward call that caused the error:
  File "C:\Program Files\WindowsApps\PythonSoftwareFoundation.Python.3.8_3.8.2800.0_x64__qbz5n2kfra8p0\lib\runpy.py", line 194, in _run_module_as_main
    return _run_code(code, main_globals, None,
  File "C:\Program Files\WindowsApps\PythonSoftwareFoundation.Python.3.8_3.8.2800.0_x64__qbz5n2kfra8p0\lib\runpy.py", line 87, in _run_code
    exec(code, run_globals)
  File "C:\Users\Srushti\AppData\Local\Packages\PythonSoftwareFoundation.Python.3.8_qbz5n2kfra8p0\LocalCache\local-packages\Python38\site-packages\ipykernel_launcher.py", line 17, in <module>
    app.launch_new_instance()
  File "C:\Users\Srushti\AppData\Local\Packages\PythonSoftwareFoundation.Python.3.8_qbz5n2kfra8p0\LocalCache\local-packages\Python38\site-packages\traitlets\config\application.py", line 1043, in launch_instance
    app.start()
  File "C:\Users\Srushti\AppData\Local\Packages\PythonSoftwareFoundation.Python.3.8_qbz5n2kfra8p0\LocalCache\local-packages\Python38\site-packages\ipykernel\kernelapp.py", line 725, in start
    self.io_loop.start()
  File "C:\Users\Srushti\AppData\Local\Packages\PythonSoftwareFoundation.Python.3.8_qbz5n2kfra8p0\LocalCache\local-packages\Python38\site-packages\tornado\platform\asyncio.py", line 195, in start
    self.asyncio_loop.run_forever()
  File "C:\Program Files\WindowsApps\PythonSoftwareFoundation.Python.3.8_3.8.2800.0_x64__qbz5n2kfra8p0\lib\asyncio\base_events.py", line 570, in run_forever
    self._run_once()
  File "C:\Program Files\WindowsApps\PythonSoftwareFoundation.Python.3.8_3.8.2800.0_x64__qbz5n2kfra8p0\lib\asyncio\base_events.py", line 1859, in _run_once
    handle._run()
  File "C:\Program Files\WindowsApps\PythonSoftwareFoundation.Python.3.8_3.8.2800.0_x64__qbz5n2kfra8p0\lib\asyncio\events.py", line 81, in _run
    self._context.run(self._callback, *self._args)
  File "C:\Users\Srushti\AppData\Local\Packages\PythonSoftwareFoundation.Python.3.8_qbz5n2kfra8p0\LocalCache\local-packages\Python38\site-packages\ipykernel\kernelbase.py", line 513, in dispatch_queue
    await self.process_one()
  File "C:\Users\Srushti\AppData\Local\Packages\PythonSoftwareFoundation.Python.3.8_qbz5n2kfra8p0\LocalCache\local-packages\Python38\site-packages\ipykernel\kernelbase.py", line 502, in process_one
    await dispatch(*args)
  File "C:\Users\Srushti\AppData\Local\Packages\PythonSoftwareFoundation.Python.3.8_qbz5n2kfra8p0\LocalCache\local-packages\Python38\site-packages\ipykernel\kernelbase.py", line 409, in dispatch_shell
    await result
  File "C:\Users\Srushti\AppData\Local\Packages\PythonSoftwareFoundation.Python.3.8_qbz5n2kfra8p0\LocalCache\local-packages\Python38\site-packages\ipykernel\kernelbase.py", line 729, in execute_request
    reply_content = await reply_content
  File "C:\Users\Srushti\AppData\Local\Packages\PythonSoftwareFoundation.Python.3.8_qbz5n2kfra8p0\LocalCache\local-packages\Python38\site-packages\ipykernel\ipkernel.py", line 422, in do_execute
    res = shell.run_cell(
  File "C:\Users\Srushti\AppData\Local\Packages\PythonSoftwareFoundation.Python.3.8_qbz5n2kfra8p0\LocalCache\local-packages\Python38\site-packages\ipykernel\zmqshell.py", line 540, in run_cell
    return super().run_cell(*args, **kwargs)
  File "C:\Users\Srushti\AppData\Local\Packages\PythonSoftwareFoundation.Python.3.8_qbz5n2kfra8p0\LocalCache\local-packages\Python38\site-packages\IPython\core\interactiveshell.py", line 3009, in run_cell
    result = self._run_cell(
  File "C:\Users\Srushti\AppData\Local\Packages\PythonSoftwareFoundation.Python.3.8_qbz5n2kfra8p0\LocalCache\local-packages\Python38\site-packages\IPython\core\interactiveshell.py", line 3064, in _run_cell
    result = runner(coro)
  File "C:\Users\Srushti\AppData\Local\Packages\PythonSoftwareFoundation.Python.3.8_qbz5n2kfra8p0\LocalCache\local-packages\Python38\site-packages\IPython\core\async_helpers.py", line 129, in _pseudo_sync_runner
    coro.send(None)
  File "C:\Users\Srushti\AppData\Local\Packages\PythonSoftwareFoundation.Python.3.8_qbz5n2kfra8p0\LocalCache\local-packages\Python38\site-packages\IPython\core\interactiveshell.py", line 3269, in run_cell_async
    has_raised = await self.run_ast_nodes(code_ast.body, cell_name,
  File "C:\Users\Srushti\AppData\Local\Packages\PythonSoftwareFoundation.Python.3.8_qbz5n2kfra8p0\LocalCache\local-packages\Python38\site-packages\IPython\core\interactiveshell.py", line 3448, in run_ast_nodes
    if await self.run_code(code, result, async_=asy):
  File "C:\Users\Srushti\AppData\Local\Packages\PythonSoftwareFoundation.Python.3.8_qbz5n2kfra8p0\LocalCache\local-packages\Python38\site-packages\IPython\core\interactiveshell.py", line 3508, in run_code
    exec(code_obj, self.user_global_ns, self.user_ns)
  File "C:\Users\Srushti\AppData\Local\Temp\ipykernel_68744\2922959437.py", line 20, in <module>
    output = model4(data)
  File "C:\Users\Srushti\AppData\Local\Packages\PythonSoftwareFoundation.Python.3.8_qbz5n2kfra8p0\LocalCache\local-packages\Python38\site-packages\torch\nn\modules\module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "C:\Users\Srushti\AppData\Local\Temp\ipykernel_68744\1791620977.py", line 24, in forward
    x = self.fc2(x).clone()
  File "C:\Users\Srushti\AppData\Local\Packages\PythonSoftwareFoundation.Python.3.8_qbz5n2kfra8p0\LocalCache\local-packages\Python38\site-packages\torch\nn\modules\module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "C:\Users\Srushti\AppData\Local\Packages\PythonSoftwareFoundation.Python.3.8_qbz5n2kfra8p0\LocalCache\local-packages\Python38\site-packages\torch\nn\modules\linear.py", line 114, in forward
    return F.linear(input, self.weight, self.bias)
 (Triggered internally at C:\actions-runner\_work\pytorch\pytorch\builder\windows\pytorch\torch\csrc\autograd\python_anomaly_mode.cpp:119.)

Hi Srushti!

Your immediate problem is likely here. You are calling loss.backward()
with retain_graph = True. First, you should think carefully about whether
you need retain_graph = True, and if so, why.

In the last iteration of your enumerate (train_loader) loop, you build a
computation graph that connects output to the parameters of model4. This
graph is preserved. In general, optimizer.step() modifies the parameters
of your model inplace. When closure() is then executed, it calls backward()
on loss_fun (output, target), which will backpropagate through the graph
that connects output to the parameters of model4. But those parameters
have been modified inplace, causing the error.

The forward-call traceback generated by set_detect_anomaly (True) is
complaining about the call to fc2 (x) in your model4. This agrees with the
above analysis in that that fc2.weight has been modified inplace by the call
to optimizer.step().

To fix this, you will need to think through the logic of your use case. Does
optimizer_qnn.step (closure) need gradients of output with respect
to the “regular” parameters of model4, such as fc2? If so, would it be practical
to rebuild the model4output graph after calling optimizer.step(), perhaps
inside of closure()?

For some examples that show how to debug and fix inplace-modification errors,
see this post:

Good luck!

K. Frank