Error in gradient computation because of an inplace operation

Hello everyone !

First i want to apologize for the bad layout of this post (it is my first time :slight_smile: ).

I am currently building an AI capable of learning the solution of a differential system (by penelizing the PDE and the boundary condition) and i have a problem during training : “one of the variables needed for gradient computation has been modified by an inplace operation”.

It happends during the training on the second call of the function “optimize” (in fact it bugs while trying to compute the second derivative of the outputs of the network with respect to the inputs).

Here is the code :

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import numpy as np
import os
os.environ["KMP_DUPLICATE_LIB_OK"]="TRUE"
import matplotlib.pyplot as plt



# We build a Network which is supposed to learn the solution of :
#           u'' - u + 2sin = 0    on Ω = [0; 5]
#           u(0) = 0   et   u(5) = sin(5)

# This system has a unique solution : 
#               u(x) = sin(x)


def exact(y):
    return torch.sin(y)

class Network(nn.Module):


    def __init__(self,dic={}):
        super(Network, self).__init__()
        self.layer1 = nn.Linear(1, 8)
        self.layer2 = nn.Linear(8, 16)
        self.layer3 = nn.Linear(16, 8)
        self.layer4 = nn.Linear(8, 1)


    def PDE_error(self,y):
        if not y.grad == None:
            y.grad.zero_()
        
        f_y = self.forward(y)
        
        n = len(y)
        f_y.backward(torch.tensor([[1]]*n),create_graph = True)
        df_y = torch.clone(y.grad)
        y.grad.zero_()
        
    
        df_y.backward((torch.tensor([[1]]*n)),create_graph=True)
        ddf_y = torch.clone(y.grad)
        y.grad.zero_()
        return torch.norm(ddf_y - f_y +2*torch.sin(y))**2/len(y)

    def boundary_error(self,y):
        return torch.norm(self.forward(y)-exact(y))**2/len(y)

    def forward(self,y):
        t = self.layer1(y)
        t = torch.tanh(t)
        t = self.layer2(t)
        t = torch.tanh(t)
        t = self.layer3(t)
        t = F.relu(t)
        f_y = self.layer4(t)
        return f_y
        
class Set():
    
    def __init__(self,n):
        self.interior = (5*torch.rand(n,1)).requires_grad_(True)
        self.boundary = torch.tensor([[0.],[5.]])

class PINN():
    
    def __init__(self,n):
        self.net = Network()
        self.set = Set(n)
        self.n = n
        self.optimiseur = optim.Adam(self.net.parameters())
    
    def delete_bug(self):
        self.set.interior = torch.clone(self.set.interior.detach_()).requires_grad_(True)
        
    def optimize(self):
        error = self.net.PDE_error(self.set.interior) + self.net.boundary_error(self.set.boundary)
        self.optimiseur.zero_grad()
        error.backward(retain_graph = True)
        self.optimiseur.step()
        # self.delete_bug()
        
    def train(self,n_iter):
        for k in range(n_iter):
            self.optimize()
            print(k)

so when i execute thoses lines :

torch.manual_seed(0)
IA = PINN(n=20)
IA.train(100)
plt.show()

x = torch.linspace(0,5,20).unsqueeze(1)
u = IA.net.forward(x)
plt.plot(np.array(x),u.detach().numpy(),'b')
plt.plot(np.linspace(0,5,20),np.sin(np.linspace(0,5,20)),'r')
plt.show()

I get the following error :

torch.manual_seed(0)
IA = PINN(n=20)
IA.train(100)
plt.show()

x = torch.linspace(0,5,20).unsqueeze(1)
u = IA.net.forward(x)
plt.plot(np.array(x),u.detach().numpy(),'b')
plt.plot(np.linspace(0,5,20),np.sin(np.linspace(0,5,20)),'r')
plt.show()
0
C:\Users\T0268083\Pictures\MyApp\Anaconca\lib\site-packages\torch\autograd\__init__.py:147: UserWarning: Error detected in AddmmBackward. Traceback of forward call that caused the error:
  File "C:\Users\T0268083\Pictures\MyApp\Anaconca\lib\runpy.py", line 193, in _run_module_as_main
    "__main__", mod_spec)
  File "C:\Users\T0268083\Pictures\MyApp\Anaconca\lib\runpy.py", line 85, in _run_code
    exec(code, run_globals)
  File "C:\Users\T0268083\Pictures\MyApp\Anaconca\lib\site-packages\spyder_kernels\console\__main__.py", line 23, in <module>
    start.main()
  File "C:\Users\T0268083\Pictures\MyApp\Anaconca\lib\site-packages\spyder_kernels\console\start.py", line 328, in main
    kernel.start()
  File "C:\Users\T0268083\Pictures\MyApp\Anaconca\lib\site-packages\ipykernel\kernelapp.py", line 677, in start
    self.io_loop.start()
  File "C:\Users\T0268083\Pictures\MyApp\Anaconca\lib\site-packages\tornado\platform\asyncio.py", line 199, in start
    self.asyncio_loop.run_forever()
  File "C:\Users\T0268083\Pictures\MyApp\Anaconca\lib\asyncio\base_events.py", line 541, in run_forever
    self._run_once()
  File "C:\Users\T0268083\Pictures\MyApp\Anaconca\lib\asyncio\base_events.py", line 1786, in _run_once
    handle._run()
  File "C:\Users\T0268083\Pictures\MyApp\Anaconca\lib\asyncio\events.py", line 88, in _run
    self._context.run(self._callback, *self._args)
  File "C:\Users\T0268083\Pictures\MyApp\Anaconca\lib\site-packages\ipykernel\kernelbase.py", line 471, in dispatch_queue
    await self.process_one()
  File "C:\Users\T0268083\Pictures\MyApp\Anaconca\lib\site-packages\ipykernel\kernelbase.py", line 460, in process_one
    await dispatch(*args)
  File "C:\Users\T0268083\Pictures\MyApp\Anaconca\lib\site-packages\ipykernel\kernelbase.py", line 367, in dispatch_shell
    await result
  File "C:\Users\T0268083\Pictures\MyApp\Anaconca\lib\site-packages\ipykernel\kernelbase.py", line 662, in execute_request
    reply_content = await reply_content
  File "C:\Users\T0268083\Pictures\MyApp\Anaconca\lib\site-packages\ipykernel\ipkernel.py", line 360, in do_execute
    res = shell.run_cell(code, store_history=store_history, silent=silent)
  File "C:\Users\T0268083\Pictures\MyApp\Anaconca\lib\site-packages\ipykernel\zmqshell.py", line 532, in run_cell
    return super().run_cell(*args, **kwargs)
  File "C:\Users\T0268083\Pictures\MyApp\Anaconca\lib\site-packages\IPython\core\interactiveshell.py", line 2915, in run_cell
    raw_cell, store_history, silent, shell_futures)
  File "C:\Users\T0268083\Pictures\MyApp\Anaconca\lib\site-packages\IPython\core\interactiveshell.py", line 2960, in _run_cell
    return runner(coro)
  File "C:\Users\T0268083\Pictures\MyApp\Anaconca\lib\site-packages\IPython\core\async_helpers.py", line 78, in _pseudo_sync_runner
    coro.send(None)
  File "C:\Users\T0268083\Pictures\MyApp\Anaconca\lib\site-packages\IPython\core\interactiveshell.py", line 3186, in run_cell_async
    interactivity=interactivity, compiler=compiler, result=result)
  File "C:\Users\T0268083\Pictures\MyApp\Anaconca\lib\site-packages\IPython\core\interactiveshell.py", line 3377, in run_ast_nodes
    if (await self.run_code(code, result,  async_=asy)):
  File "C:\Users\T0268083\Pictures\MyApp\Anaconca\lib\site-packages\IPython\core\interactiveshell.py", line 3457, in run_code
    exec(code_obj, self.user_global_ns, self.user_ns)
  File "C:\Users\T0268083\AppData\Local\Temp\ipykernel_14232\3067292002.py", line 3, in <module>
    IA.train(100)
  File "C:\Users\T0268083\Pictures\MyApp\Spyder\ex_stovf.py", line 90, in train
    self.optimize()
  File "C:\Users\T0268083\Pictures\MyApp\Spyder\ex_stovf.py", line 82, in optimize
    error = self.net.PDE_error(self.set.interior) + self.net.boundary_error(self.set.boundary)
  File "C:\Users\T0268083\Pictures\MyApp\Spyder\ex_stovf.py", line 38, in PDE_error
    f_y = self.forward(y)
  File "C:\Users\T0268083\Pictures\MyApp\Spyder\ex_stovf.py", line 59, in forward
    t = self.layer3(t)
  File "C:\Users\T0268083\Pictures\MyApp\Anaconca\lib\site-packages\torch\nn\modules\module.py", line 889, in _call_impl
    result = self.forward(*input, **kwargs)
  File "C:\Users\T0268083\Pictures\MyApp\Anaconca\lib\site-packages\torch\nn\modules\linear.py", line 94, in forward
    return F.linear(input, self.weight, self.bias)
  File "C:\Users\T0268083\Pictures\MyApp\Anaconca\lib\site-packages\torch\nn\functional.py", line 1753, in linear
    return torch._C._nn.linear(input, weight, bias)
 (Triggered internally at  ..\torch\csrc\autograd\python_anomaly_mode.cpp:104.)
  allow_unreachable=True, accumulate_grad=True)  # allow_unreachable flag
Traceback (most recent call last):

  File "C:\Users\T0268083\AppData\Local\Temp\ipykernel_14232\3067292002.py", line 3, in <module>
    IA.train(100)

  File "C:\Users\T0268083\Pictures\MyApp\Spyder\ex_stovf.py", line 90, in train
    self.optimize()

  File "C:\Users\T0268083\Pictures\MyApp\Spyder\ex_stovf.py", line 82, in optimize
    error = self.net.PDE_error(self.set.interior) + self.net.boundary_error(self.set.boundary)

  File "C:\Users\T0268083\Pictures\MyApp\Spyder\ex_stovf.py", line 46, in PDE_error
    df_y.backward((torch.tensor([[1]]*n)),create_graph=True)

  File "C:\Users\T0268083\Pictures\MyApp\Anaconca\lib\site-packages\torch\tensor.py", line 245, in backward
    torch.autograd.backward(self, gradient, retain_graph, create_graph, inputs=inputs)

  File "C:\Users\T0268083\Pictures\MyApp\Anaconca\lib\site-packages\torch\autograd\__init__.py", line 147, in backward
    allow_unreachable=True, accumulate_grad=True)  # allow_unreachable flag

RuntimeError: one of the variables needed for gradient computation has been modified by an inplace operation: [torch.FloatTensor [16, 8]], which is output 0 of TBackward, is at version 2; expected version 1 instead. Hint: the backtrace further above shows the operation that failed to compute its gradient. The variable in question was changed in there or anywhere later. Good luck!

As you may see, i fixed this issue by replacing the subset IA.set.interior with the function “delete_bug” (currently commented in the script) and it works if i use this function but when i use cuda and my GPU for training, this solution causes my graphic memory to slowly increase until it’s full. So I am looking for the real solution and not a workaround.

I am available for any additional information about my script and I can join you the “cuda-version” if you want to test it.

Sincerely,
Matthieu