Train on MSE with transformed neural network output

I would like to train a neural network that minimizes the MSE between a transformation of the network's output and a target, i.e.:

    x_pred = NN(x)
    y_pred = f(x_pred)
    loss = loss_func(y_pred, y)
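
Written out for a batch of N training pairs (x_i, y_i), with network parameters θ and d components per output (notation introduced here for illustration), the objective and the gradient that backpropagation has to compute look like this. The 1/(Nd) factor matches the default mean reduction of torch.nn.MSELoss:

$$
\hat{x}_i = \mathrm{NN}_\theta(x_i), \qquad
\mathcal{L}(\theta) = \frac{1}{N d} \sum_{i=1}^{N} \big\lVert f(\hat{x}_i) - y_i \big\rVert_2^2
$$

$$
\frac{\partial \mathcal{L}}{\partial \theta}
= \frac{2}{N d} \sum_{i=1}^{N} \big( f(\hat{x}_i) - y_i \big)^{\top}
\left. \frac{\partial f}{\partial \hat{x}} \right|_{\hat{x}_i}
\frac{\partial \hat{x}_i}{\partial \theta}
$$

The Jacobian ∂f/∂x̂ can only be computed if every operation inside f is recorded by autograd; if f cuts the graph anywhere, no gradient reaches θ.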

Here, f is a transformation of x_pred. I tried it like this:

    # train the network
    for epoch in range(EPOCH):

        # for each training step
        for step, (batch_x, batch_y) in enumerate(loader):
            b_x, b_y = batch_x, batch_y  # plain tensors work directly; Variable is deprecated since PyTorch 0.4

            prediction = net(b_x)  # input x and predict based on x

            x_plus = integrate(sampling_interval, prediction, b_x)

            loss = loss_func(x_plus, b_y)  # must be (1. nn output, 2. target)

            optimizer.zero_grad()  # clear gradients for next train
            loss.backward()  # backpropagation, compute gradients
            optimizer.step()  # apply gradients
            print(loss.item())

where the function integrate (i.e. f) is:

    import torch


    def integrate(sampling_interval, r, x0):
        M = 4  # RK4 steps per interval
        DT = sampling_interval / M
        X = x0

        for j in range(M):
            k1 = ode(X, r)
            k2 = ode(X + DT / 2 * k1, r)
            k3 = ode(X + DT / 2 * k2, r)
            k4 = ode(X + DT * k3, r)
            X = X + DT / 6 * (k1 + 2 * k2 + 2 * k3 + k4)

        return X


    def ode(x0, r):
        # parameters predicted by the network
        qF = r[0, 0]
        qA = r[0, 1]
        qP = r[0, 2]
        mu = r[0, 3]

        # current state
        FRU = x0[0, 0]
        AMC = x0[0, 1]
        PHB = x0[0, 2]
        TBM = x0[0, 3]

        # right-hand side of the ODE system
        fFRU = qF * TBM
        fAMC = qA * TBM
        fPHB = qP - mu * PHB
        fTBM = mu * TBM

        return torch.tensor([fFRU, fAMC, fPHB, fTBM])
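
The integrate function above is the classical fourth-order Runge-Kutta (RK4) method with M = 4 sub-steps per sampling interval. As a sanity check that does not involve PyTorch, the following sketch applies the same update to the scalar test equation dx/dt = -x, whose exact solution is x(t) = x0·e^(-t). The test equation and the helper name rk4_integrate are assumptions chosen for illustration; they are not the model from the question.

```python
import math


def rk4_integrate(f, x0, interval, steps=4):
    """Integrate dx/dt = f(x) over `interval` using `steps` classical RK4 steps.

    Mirrors the structure of `integrate` above (autonomous ODE, no explicit t).
    """
    dt = interval / steps
    x = x0
    for _ in range(steps):
        k1 = f(x)
        k2 = f(x + dt / 2 * k1)
        k3 = f(x + dt / 2 * k2)
        k4 = f(x + dt * k3)
        x = x + dt / 6 * (k1 + 2 * k2 + 2 * k3 + k4)
    return x


# Test problem: dx/dt = -x with x(0) = 1, exact solution x(1) = exp(-1)
approx = rk4_integrate(lambda x: -x, 1.0, 1.0, steps=4)
exact = math.exp(-1.0)
print(approx, exact, abs(approx - exact))
```

Doubling the number of sub-steps should shrink the error by roughly 2^4 = 16, the behavior expected of a fourth-order method, which makes it easy to confirm the RK4 coefficients are right before debugging the training loop itself.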

But when I do this, the MSE jumps from one value to another and does not converge. Am I making a mistake?

Thanks a lot

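I found the problem: I was losing the gradient at the end of the ode function. torch.tensor([...]) copies only the numeric values into a new tensor that is not connected to the autograd graph, so no gradient could flow back to the network. Stacking the existing tensors instead keeps the graph intact: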

    def ode(self, x0, r):
        qF = r[0, 0]
        qA = r[0, 1]
        qP = r[0, 2]
        mu = r[0, 3]

        FRU = x0[0, 0]
        AMC = x0[0, 1]
        PHB = x0[0, 2]
        TBM = x0[0, 3]

        fFRU = qF * TBM
        fAMC = qA * TBM
        fPHB = qP - mu * PHB
        fTBM = mu * TBM

        # torch.stack is an autograd operation, so the result stays connected to r
        return torch.stack((fFRU, fAMC, fPHB, fTBM), 0)
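
To make the difference concrete, here is a minimal sketch with made-up values for r and x0 (hypothetical stand-ins, shaped (1, 4) as in the question). It builds the same four right-hand-side terms once, combines them both ways, and then checks which result can still backpropagate to r.

```python
import torch

# Hypothetical stand-in values with the shapes used in the question:
# r plays the role of the network output, x0 the current state, both (1, 4).
r = torch.tensor([[0.1, 0.2, 0.3, 0.4]], requires_grad=True)
x0 = torch.tensor([[1.0, 2.0, 3.0, 4.0]])

qF, qA, qP, mu = r[0, 0], r[0, 1], r[0, 2], r[0, 3]
PHB, TBM = x0[0, 2], x0[0, 3]
terms = (qF * TBM, qA * TBM, qP - mu * PHB, mu * TBM)  # 0-dim tensors on the graph

# Question's approach: torch.tensor builds a new leaf tensor from the numeric
# values only (requires_grad=False), so the link back to r is cut.
# .item() makes that copy explicit here.
copied = torch.tensor([t.item() for t in terms])

# Answer's approach: torch.stack is itself recorded by autograd.
stacked = torch.stack(terms, 0)

print(copied.requires_grad, copied.grad_fn)  # False None
print(stacked.requires_grad)                 # True

# Backpropagate through the stacked result: the gradient reaches r.
stacked.sum().backward()
print(r.grad)  # tensor([[4., 4., 1., 1.]])
```

On the word "concatenate": torch.cat rejects 0-dim tensors, so for scalar terms like these torch.stack, which adds a new dimension, is the natural choice; both preserve the autograd graph. Note also that the stacked result has shape (4,) while X has shape (1, 4), so X + DT / 2 * k1 broadcasts back to (1, 4).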