Can't run backpropagation through my function

Hi everyone, I am new to PyTorch and I have a simple training loop:

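for epoch in range(num_epochs):
    for i, data in enumerate(training_set):
        output_layer = PIAE_model(data)
        loss = criterion(output_layer, data)
        optimizer.zero_grad()  # optimizer assumed to be defined elsewhere
        loss.backward()
        optimizer.step()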

The model's methods look like this:

import torch

def inverter(self, x):
    x = self.up(self.act(self.bn1(self.conv1(x))))  # convolution, batch normalization...
    x = self.up(self.act(self.bn2(self.conv2(x))))
    x = self.up(self.bn3(torch.abs(self.conv3(x))))
    return x

def findif(self, vel_model):
    for n in range(self.nt):  # time steps
        d2 = []
        # Start from the current wavefield; rebinding a list element (instead of
        # writing into self.p) keeps the source injection out-of-place.
        sourceINJ = list(self.p.unbind(0))
        for i in range(self.nx):  # for all the grid points
            # Note: at i = 0, self.p[i - 1] is self.p[-1] (Python wraps negative indices).
            d2.append((self.p[i + 1].detach() - 2 * self.p[i].detach() + self.p[i - 1].detach()) / self.dx ** 2)

        sourceINJ[1] = sourceINJ[1] + self.q[n]  # inject the source at grid point 1
        self.d2pdx2 = torch.stack(d2)
        self.p = torch.stack(sourceINJ)

        self.p = 2 * self.p.detach() + vel_model ** 2 * self.dt * self.d2pdx2.detach()
        self.traces[n] = self.p[1]
    return self.traces

def forward(self, x):
    vel_model = self.inverter(x)
    seis_model = self.findif(vel_model)
    return seis_model

In the findif method I use list append to avoid in-place operations: the d2 list computes the second spatial derivative, and sourceINJ performs the source injection self.p[1] = self.p[1] + self.q[n]. The variable self.p is a tensor of zeros at first and, after the first iteration, becomes a function of vel_model; self.d2pdx2 is a function of self.p. self.dt is the scalar 1.
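For comparison, here is a minimal sketch of the same two steps written with slicing and torch.cat instead of Python loops and lists. The function names second_derivative and inject_source are hypothetical, and the sketch assumes self.p is a 1-D tensor. Unlike the loop in findif, it evaluates the derivative only at interior points (so it returns len(p) - 2 values rather than wrapping around at i = 0), and it does not detach anything. Both operations build new tensors instead of writing into existing ones, so autograd can track them.

```python
import torch


def second_derivative(p: torch.Tensor, dx: float) -> torch.Tensor:
    """Central second difference at the interior points of a 1-D tensor.

    Returns a new tensor of length len(p) - 2; nothing is written in place.
    """
    return (p[2:] - 2 * p[1:-1] + p[:-2]) / dx ** 2


def inject_source(p: torch.Tensor, q_n: torch.Tensor, idx: int = 1) -> torch.Tensor:
    """Return a new tensor equal to p with q_n added at position idx (out-of-place)."""
    return torch.cat([p[:idx], p[idx:idx + 1] + q_n, p[idx + 1:]])


if __name__ == "__main__":
    p = torch.tensor([0.0, 1.0, 4.0, 9.0, 16.0], requires_grad=True)  # samples of x**2
    print(second_derivative(p, dx=1.0))  # all entries equal 2 for a quadratic
    p_new = inject_source(p, torch.tensor(0.5))
    p_new.sum().backward()
    print(p.grad)  # gradient flows back through the injection: all ones
```

Whether to detach self.p between time steps is a separate decision; this sketch only shows that the derivative and the injection themselves need neither lists nor in-place writes.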

So the problem lies in this line: self.p = 2 * self.p.detach() + vel_model ** 2 * self.dt * self.d2pdx2.detach(). In this form the gradients are not computed properly: I get NaNs in my loss, and setting torch.autograd.set_detect_anomaly(True) raises "Function 'CudnnBatchNormBackward0' returned nan values in its 0th output."

Replacing the plus sign in the formula with a multiplication, i.e. 2 * self.p.detach() * vel_model ** 2 * self.dt * self.d2pdx2.detach(), lets the training run smoothly. Removing self.d2pdx2 and the coefficient 2, i.e. self.p.detach() + vel_model ** 2 * self.dt, also runs smoothly.
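One possibility worth ruling out before looking at the backward pass is that the forward values themselves overflow. The update p = 2p + (...) has no term that pulls p back down, so p roughly doubles at every time step, and float32 tops out near 3.4e38, which is about 2^128. The toy sketch below replaces vel_model ** 2 * self.dt * self.d2pdx2 with a hypothetical constant forcing term (the real term depends on p, so this only illustrates the growth of the 2p part) and reports the first time step at which p is no longer finite. For contrast it also runs the form of the usual second-order (leapfrog) time step for the wave equation, which subtracts the previous wavefield: p_next = 2p - p_prev + v**2 * dt**2 * d2p/dx2. Whether that scheme is the one intended here is an assumption.

```python
import torch


def first_nonfinite_step(update, n_steps=200):
    """Iterate a scalar float32 recurrence and return the first step (1-based)
    at which the value becomes inf or NaN, or None if it stays finite."""
    p_prev = torch.tensor(0.0)  # float32 by default
    p = torch.tensor(0.0)
    for n in range(1, n_steps + 1):
        p_prev, p = p, update(p, p_prev)
        if not torch.isfinite(p):
            return n
    return None


# Hypothetical stand-in for vel_model ** 2 * self.dt * self.d2pdx2 (held constant here).
forcing = torch.tensor(1.0)

# Same form as the update in findif: p = 2 * p + forcing.
doubling = first_nonfinite_step(lambda p, p_prev: 2 * p + forcing)

# Form of the usual leapfrog time step: p = 2 * p - p_prev + forcing.
leapfrog = first_nonfinite_step(lambda p, p_prev: 2 * p - p_prev + forcing)

print("2p + f:          first non-finite step =", doubling)
print("2p - p_prev + f: first non-finite step =", leapfrog)
```

If the doubling form goes non-finite in fewer steps than self.nt, an overflow in the forward pass is a plausible source of the NaNs; checking torch.isfinite(self.p).all() inside the time loop of findif would confirm or rule this out.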

Please help me circumvent this problem.

I tried using list append, as before:

test = []
for i in range(self.nx):
    test.append(2 * self.p[i].detach() + vel_model[i] ** 2 * self.dt * self.d2pdx2[i].detach())
self.p = torch.stack(test)

But this didn't help. I also tried cloning some tensors, with no success.
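Stacking per-element results computes the same values as the single tensor expression, so the list-based version cannot change the forward result (or any overflow); it only makes the autograd graph larger. A minimal check with random stand-in tensors follows. The sizes, the seed, and the 1-D shapes are assumptions for illustration; in the real model vel_model comes out of the convolutional inverter and may have a different shape.

```python
import torch

torch.manual_seed(0)
nx, dt = 8, 1.0

# Random stand-ins for self.p, vel_model and self.d2pdx2 (assumed 1-D, length nx).
p = torch.randn(nx)
vel_model = torch.randn(nx, requires_grad=True)
d2pdx2 = torch.randn(nx)

# Vectorized form of the update.
vectorized = 2 * p.detach() + vel_model ** 2 * dt * d2pdx2.detach()

# List-append form from the post.
test = []
for i in range(nx):
    test.append(2 * p[i].detach() + vel_model[i] ** 2 * dt * d2pdx2[i].detach())
looped = torch.stack(test)

print(torch.allclose(vectorized, looped))  # the two forms agree
```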

Thank you for reading!