Minimize equation result inside training loop

I am trying to minimize the results of some operations that involve the output of a model inside my training loop.
First I use the outputs of the model to compute a set of equations. Then I compute new values based on the equation results. I calculate a loss for these new values and sum the losses before calling backward().
For some reason, I get the following error: element 0 of tensors does not require grad and does not have a grad_fn
Can anyone spot what I am doing wrong? Am I performing a forbidden operation inside the training loop?
Here is an example of what I am doing:


x = {'input data with multiple columns'}
y = {'targets'}
t = {'values useful for future computation'}
class NeuralNetwork(nn.Module):
    def __init__(self):
        super(NeuralNetwork, self).__init__()
        self.linear_init = nn.Linear(in_features=9, out_features=20)
        self.linear1 = nn.Linear(in_features=20, out_features=3)
        self.linear2 = nn.Linear(in_features=20, out_features=3)
        self.linear3 = nn.Linear(in_features=20, out_features=3)

    def forward(self, x):
        x = torch.relu(self.linear_init(x))
        output1 = self.linear1(x)
        output2 = self.linear2(x)
        output3 = self.linear3(x)
        return output1, output2, output3

model = NeuralNetwork()

def formula1(param1, param2, param3):
    resolve = {'equation that involves params and some columns of the input data'}
    return resolve

def formula2(param1, param2, param3):
    resolve = {'equation that involves params and some columns of the input data'}
    return resolve

def formula3(param1, param2, param3):
    resolve = {'equation that involves params and some columns of the input data'}
    return resolve

criterion = nn.MSELoss(reduction='mean')
optimiser = torch.optim.SGD(model.parameters(), lr=1e-5)

input_tensor = torch.from_numpy(x).type(torch.Tensor)
target_tensor = torch.from_numpy(y).type(torch.Tensor)
step = torch.from_numpy(t).type(torch.Tensor)

n = len(input_tensor)

for t in range(10):
    val1 = 0
    val2 = 0
    val3 = 0
    res1 = torch.zeros(n, 1)
    res2 = torch.zeros(n, 1)
    res3 = torch.zeros(n, 1)

    out1, out2, out3 = model(input_tensor)

    inter1 = formula1(out1[:, 0], out2[:, 0], out3[:, 0])
    inter2 = formula2(out1[:, 1], out2[:, 1], out3[:, 1])
    inter3 = formula3(out1[:, 2], out2[:, 2], out3[:, 2])

    for i in range(len(input_tensor)):
        # there is a condition on the value when i = 0; these are really arbitrary operations, just an example!
        A = ((step[i] - step[i-1]) / (inter1[i] + inter1[i-1])) * 0.3
        B = ((step[i] - step[i-1]) / (inter2[i] + inter2[i-1])) * 0.3
        C = ((step[i] - step[i-1]) / (inter3[i] + inter3[i-1])) * 0.3
        val1 = val1 + A
        val2 = val2 + B
        val3 = val3 + C

        res1[i] = val1
        res2[i] = val2
        res3[i] = val3
    
    loss1 = criterion(res1, 'a column of target tensor')
    loss2 = criterion(res2, 'a column of target tensor')
    loss3 = criterion(res3, 'a column of target tensor')
    
    cum_loss = loss1+loss2+loss3
    optimiser.zero_grad()
    cum_loss.backward()
    optimiser.step()

You are most likely detaching a tensor from the computation graph, e.g. by recreating a tensor, by calling detach() explicitly on a tensor, or by using a 3rd-party library (such as numpy) for some operations.
Your code is unfortunately incomplete and not executable, so I don’t know what exactly causes the issue.
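
For illustration, here is a minimal standalone sketch (not your code, just one common failure mode) showing how routing a tensor through numpy reproduces exactly this error:

import torch

w = torch.randn(3, requires_grad=True)
y = w * 2                         # still attached to the computation graph
y_np = y.detach().numpy()         # leaves the autograd graph
z = torch.from_numpy(y_np).sum()  # z.requires_grad is now False
z.backward()                      # RuntimeError: element 0 of tensors does not
                                  # require grad and does not have a grad_fn

Any pattern like this breaks the chain of grad_fns between the loss and the model parameters, so backward() has nothing to differentiate.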

Thank you for the quick response. I am not calling detach() anywhere inside the training loop. I have also tried using torch's trapezoid function for the operation, but I still get the same error.
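
For what it's worth, a quick standalone check suggests that torch.trapezoid itself stays differentiable when its input requires grad, so the function alone shouldn't be detaching anything:

y = torch.tensor([1.0, 2.0, 3.0], requires_grad=True)
x = torch.tensor([0.0, 0.5, 1.0])
area = torch.trapezoid(y, x=x)  # trapezoidal integral of y over x
area.backward()
print(y.grad)                   # tensor([0.2500, 0.5000, 0.2500])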

Sorry for the incomplete code. Here is what the current code with the operations looks like:

import torch.nn as nn
import torch
import pandas as pd

df_data = pd.read_csv('eff_evo.csv')

x = (df_data[['mat', 'sol', 'tex', 'pb4', 'cal_56', 'r_e', 'ttr', 'co', 'mix', 'pt15']]).to_numpy()
t = df_data['logging'].to_numpy().reshape(-1,1)
class NeuralNetwork(nn.Module):
    def __init__(self, input_size, hidden1, output_size):
        super(NeuralNetwork, self).__init__()
        self.input_size = input_size
        self.output_size = output_size
        self.hidden1 = hidden1
        self.linear_init = nn.Linear(self.input_size, self.hidden1)
        self.linear1 = nn.Linear(self.hidden1, output_size)
        self.linear2 = nn.Linear(self.hidden1, output_size)
        self.linear3 = nn.Linear(self.hidden1, output_size)

    def forward(self, x):
        x = torch.relu(self.linear_init(x))
        output1 = self.linear1(x)
        output2 = self.linear2(x)
        output3 = self.linear3(x)
        return output1, output2, output3

model = NeuralNetwork(x.shape[1], 13, 3)

criterion = nn.MSELoss(reduction='mean')
optimiser = torch.optim.SGD(model.parameters(), lr=1e-5)

input_tensor = torch.from_numpy(x).type(torch.Tensor)
ste = torch.from_numpy(t).type(torch.Tensor)
step = ste.reshape(-1)

n = len(step)
res1 = torch.zeros(n, 1)
res2 = torch.zeros(n, 1)
res3 = torch.zeros(n, 1)
for t in range(10):

    out1, out2, out3 = model(input_tensor)

    inter1 = (out1[:, 0] * input_tensor[:, 4]
              + out2[:, 0] * input_tensor[:, 5]
              + out3[:, 0] * input_tensor[:, 6]).reshape(-1, 1)
    inter2 = (out1[:, 1] * input_tensor[:, 0]
              - out2[:, 1] * input_tensor[:, 1]
              + out3[:, 1] * input_tensor[:, 2]).reshape(-1, 1)
    inter3 = (out1[:, 2] * input_tensor[:, 3]
              - out2[:, 2] * input_tensor[:, 7]
              - out3[:, 2] * input_tensor[:, 8]).reshape(-1, 1)

    for i in range(len(input_tensor)):
        res1[i] = torch.trapezoid(inter1.reshape(-1)[0:i+1], x=step[0:i+1]) + input_tensor[:, 0][0]
        res2[i] = torch.trapezoid(inter2.reshape(-1)[0:i+1], x=step[0:i+1]) + input_tensor[:, 1][0]
        res3[i] = torch.trapezoid(inter3.reshape(-1)[0:i+1], x=step[0:i+1]) + input_tensor[:, 2][0]
    
    loss1 = criterion(res1, input_tensor[:, 0].reshape(-1, 1))
    loss2 = criterion(res2, input_tensor[:, 1].reshape(-1, 1))
    loss3 = criterion(res3, input_tensor[:, 2].reshape(-1, 1))
    
    cum_loss = loss1+loss2+loss3
    
    print("Epoch ", t, "MSE: ", cum_loss.item())
    
    optimiser.zero_grad()
    cum_loss.backward()
    optimiser.step()

Can you spot what might be causing the issue?
Could the issue be that the tensors res1, res2, res3 that I am trying to minimize are created manually and then assigned values in place?

No, I don’t see what might be causing the issue. However, using random input tensors:

model = NeuralNetwork(10, 13, 3)
criterion = nn.MSELoss(reduction='mean')
optimiser = torch.optim.SGD(model.parameters(), lr=1e-5)

input_tensor = torch.randn(10, 10)
step =  torch.randn(10)

I get a different error:

# RuntimeError: Trying to backward through the graph a second time (or directly access saved tensors after they have already been freed). Saved intermediate values of the graph are freed when you call .backward() or autograd.grad(). Specify retain_graph=True if you need to backward through the graph a second time or if you need to access saved tensors after calling backward.

which is caused by the in-place manipulation of resX: since res1, res2, and res3 are created once outside the training loop, the in-place writes in each new epoch extend the autograd graph of the previous epoch, whose intermediate buffers were already freed by the first backward() call.
Using:

    res1_tmp = []
    res2_tmp = []
    res3_tmp = []
    for i in range(len(input_tensor)):
        #res1[i] = torch.trapezoid(inter1.reshape(-1)[0:i + 1], x=step[0:i+1])+ input_tensor[:,0][0]
        #res2[i] = torch.trapezoid(inter2.reshape(-1)[0:i + 1], x=step[0:i+1])+ input_tensor[:,1][0]
        #res3[i] = torch.trapezoid(inter3.reshape(-1)[0:i + 1], x=step[0:i+1])+ input_tensor[:,2][0]
        res1_tmp.append(torch.trapezoid(inter1.reshape(-1)[0:i + 1], x=step[0:i+1])+ input_tensor[:,0][0])
        res2_tmp.append(torch.trapezoid(inter2.reshape(-1)[0:i + 1], x=step[0:i+1])+ input_tensor[:,1][0])
        res3_tmp.append(torch.trapezoid(inter3.reshape(-1)[0:i + 1], x=step[0:i+1])+ input_tensor[:,2][0])
    res1 = torch.stack(res1_tmp).unsqueeze(1)
    res2 = torch.stack(res2_tmp).unsqueeze(1)
    res3 = torch.stack(res3_tmp).unsqueeze(1)

allows me to run the script properly with the random input tensors.
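
As a side note, and assuming your torch version ships torch.cumulative_trapezoid (available in recent releases), the inner Python loop can be replaced by a single vectorized call per output; prepending a zero accounts for the fact that the integral over a single sample is 0:

    # assumes torch.cumulative_trapezoid is available in your torch version
    zero = step.new_zeros(1)  # the integral over the first sample alone is 0
    cum1 = torch.cumulative_trapezoid(inter1.reshape(-1), x=step)
    cum2 = torch.cumulative_trapezoid(inter2.reshape(-1), x=step)
    cum3 = torch.cumulative_trapezoid(inter3.reshape(-1), x=step)
    res1 = (torch.cat([zero, cum1]) + input_tensor[:, 0][0]).unsqueeze(1)
    res2 = (torch.cat([zero, cum2]) + input_tensor[:, 1][0]).unsqueeze(1)
    res3 = (torch.cat([zero, cum3]) + input_tensor[:, 2][0]).unsqueeze(1)

This avoids writing into preallocated tensors entirely and should also be much faster than the per-sample loop.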

Thanks a lot! Using the temporary lists allowed the script to run properly.