Awkward error about backwarding through the graph a second time

Hi, I get a strange error in the second iteration after the first iteration worked fine (PyTorch 2.1.2):

Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
RuntimeError: Trying to backward through the graph a second time (or directly access saved tensors after they have already been freed). Saved intermediate values of the graph are freed when you call .backward() or autograd.grad(). Specify retain_graph=True if you need to backward through the graph a second time or if you need to access saved tensors after calling backward.

Although I do not call backward twice. With retain_graph=True it instead complains that the parameters are at different versions after just one iteration.

import cv2

from settings import global_settings as gs
import numpy as np
from matplotlib import pyplot as plt

import torch
import torch.optim as optim

import os

os.environ["KMP_DUPLICATE_LIB_OK"] = "TRUE"


def read_values_from_file(theoretical_valuesX_path, theoretical_valuesY_path,
                          measured_valuesX_path, measured_valuesY_path):
    with open(theoretical_valuesX_path) as fileTX, \
            open(theoretical_valuesY_path) as fileTY, \
            open(measured_valuesX_path) as fileMX, \
            open(measured_valuesY_path) as fileMY:
        theoretical_valuesX = [float(x) for x in fileTX.read().replace("\n", " ")[:-1].split(" ")]
        theoretical_valuesY = [float(y) for y in fileTY.read().replace("\n", " ")[:-1].split(" ")]
        measured_valuesX = [float(x) for x in fileMX.read().replace("\n", " ")[:-1].split(" ")]
        measured_valuesY = [float(y) for y in fileMY.read().replace("\n", " ")[:-1].split(" ")]

        return theoretical_valuesX, theoretical_valuesY, measured_valuesX, measured_valuesY


if __name__ == "__main__":
    # load camera data
    dragonfly_focal = 70.339  # mm
    FOCAL_LENGTH_m = 1e-3 * dragonfly_focal
    pix_pitch_m = 1e-6 * gs.CAMERA_SETTINGS['PIXEL_PITCH']
    focal_pixels = FOCAL_LENGTH_m / pix_pitch_m

    device = "cuda" if torch.cuda.is_available() else "cpu"

    # load dragonfly data
    # theoretical_valuesX_path = r"C:\Users\YOSSI\Desktop\PolarizedCameraCalib-scripts\notebooks\simulated\theoretical_valuesX.txt"
    # theoretical_valuesY_path = r"C:\Users\YOSSI\Desktop\PolarizedCameraCalib-scripts\notebooks\simulated\theoretical_valuesY.txt"
    # measured_valuesX_path = r"C:\Users\YOSSI\Desktop\PolarizedCameraCalib-scripts\notebooks\simulated\measured_valuesX.txt"
    # measured_valuesY_path = r"C:\Users\YOSSI\Desktop\PolarizedCameraCalib-scripts\notebooks\simulated\measured_valuesY.txt"

    theoretical_valuesX_path = r"C:\Users\YOSSI\Desktop\PolarizedCameraCalib-scripts\data for calibration\theoretical_valuesX.txt"
    theoretical_valuesY_path = r"C:\Users\YOSSI\Desktop\PolarizedCameraCalib-scripts\data for calibration/theoretical_valuesY.txt"
    measured_valuesX_path = r"C:\Users\YOSSI\Desktop\PolarizedCameraCalib-scripts\data for calibration/measured_valuesX.txt"
    measured_valuesY_path = r"C:\Users\YOSSI\Desktop\PolarizedCameraCalib-scripts\data for calibration/measured_valuesY.txt"

    theoretical_values_cols, theoretical_values_rows, \
        measured_values_cols, measured_values_rows = read_values_from_file(
        theoretical_valuesX_path, theoretical_valuesY_path,
        measured_valuesX_path, measured_valuesY_path)

    # define world_points & points_on_image for optimization

    theoretical_valuesXYZ = np.stack(
        (np.array(theoretical_values_rows) / focal_pixels,
         np.array(theoretical_values_cols) / focal_pixels,
         np.ones_like(theoretical_values_rows)), axis=1)

    theoretical_valuesXY = np.stack(
        (np.array(theoretical_values_rows),
         np.array(theoretical_values_cols)), axis=1)

    # theoretical_valuesXYZ = np.stack((np.array(theoretical_valuesX) *
    #                                  pix_pitch_m / (dragonfly_focal * 0.001),
    #                                  np.array(theoretical_valuesY) *
    #                                  pix_pitch_m / (dragonfly_focal * 0.001),
    #                                  np.ones_like(theoretical_valuesX)),
    #                                  axis=1)
    # meters
    z0 = 500000

    # in meters
    world_pts = torch.tensor(theoretical_valuesXYZ * z0, device=device)

    pts_on_image = torch.tensor(np.stack((measured_values_rows, measured_values_cols),
                                         axis=1), device=device)

    # initial guess
    nx, ny = gs.CAMERA_SETTINGS['CHeight'], gs.CAMERA_SETTINGS['CWidth']

    init_focal = 10  # mm
    init_FOCAL_LENGTH_m = 1e-3 * init_focal
    pix_pitch_m = 1e-6 * gs.CAMERA_SETTINGS['PIXEL_PITCH']
    init_focal_pixels = init_FOCAL_LENGTH_m / pix_pitch_m

    # Suppose you have some initial guess for K, R, and distortion coefficients
    # For demonstration purposes, let's initialize them randomly

    fx_init = torch.tensor([init_focal_pixels],
                           dtype=torch.float64, requires_grad=False,
                           device=device)

    fy_init = torch.tensor([init_focal_pixels],
                           dtype=torch.float64, requires_grad=False,
                           device=device)

    cx_init = torch.tensor([0.],
                           dtype=torch.float64, requires_grad=False,
                           device=device)

    cy_init = torch.tensor([0.],
                           dtype=torch.float64, requires_grad=False,
                           device=device)

    alpha_init = torch.tensor([0.],
                              dtype=torch.float64, requires_grad=True,
                              device=device)

    beta_init = torch.tensor([0.],
                              dtype=torch.float64, requires_grad=True,
                              device=device)

    gamma_init = torch.tensor([0.],
                              dtype=torch.float64, requires_grad=True,
                              device=device)
    '''
    R_init = torch.stack([
        torch.cos(beta_init) * torch.cos(gamma_init),
        torch.sin(alpha_init) * torch.sin(beta_init) * torch.cos(gamma_init) -
        torch.cos(alpha_init) * torch.sin(gamma_init),
        torch.cos(alpha_init) * torch.sin(beta_init) * torch.cos(gamma_init) +
        torch.sin(alpha_init) * torch.sin(gamma_init),
        torch.sin(alpha_init) * torch.cos(beta_init),
        torch.sin(alpha_init) * torch.sin(beta_init) * torch.sin(gamma_init) +
        torch.cos(alpha_init) * torch.cos(gamma_init),
        torch.cos(alpha_init) * torch.sin(beta_init) * torch.sin(gamma_init) -
        torch.sin(alpha_init) * torch.cos(gamma_init),
        -torch.sin(alpha_init), torch.sin(alpha_init) * torch.cos(beta_init),
        torch.cos(alpha_init) * torch.cos(beta_init)
    ]).view(3, 3)
    '''
    R_init = torch.stack([torch.tensor([1], dtype=torch.float64, device=device),
                          torch.tensor([0], dtype=torch.float64, device=device),
                          torch.tensor([0], dtype=torch.float64, device=device),
                          torch.tensor([0], dtype=torch.float64, device=device),
                          torch.cos(beta_init), -torch.sin(beta_init),
                          torch.tensor([0], dtype=torch.float64, device=device),
                          torch.sin(beta_init),
                          torch.cos(beta_init)]).view(3, 3)

    #Rinv_init = R_init.t()

    distortion_init = torch.tensor([0, 0, 0], dtype=torch.float64, requires_grad=False, device=device)


    def project_points(world_points, focal_x, focal_y, cx, cy, R, z0):
        # Apply rotation (R) to world points
        #rotated_points = torch.matmul(R, world_points.t()).t()
        rotated_points = torch.matmul(world_points, R.t())

        # Apply intrinsic matrix (K)
        rotated_x, rotated_y, rotated_z = rotated_points.T

        projected_x = rotated_x * focal_x + cx

        projected_y = rotated_y * focal_y + cy

        projected_points = torch.stack((projected_x, projected_y, rotated_z), dim=-1)
        projected_points = projected_points[:, :2] / z0

        return projected_points


    def deproject_points(points_2d, fx, fy, cx, cy, Rinv, z0):

        Pinv = Rinv
        # Apply rotation (R) to world points
        homo_x, homo_y = points_2d.T

        homo_x = (homo_x - cx) / fx
        homo_y = (homo_y - cy) / fy

        homo_2d_points = torch.stack((homo_x, homo_y, torch.ones_like(homo_x)), dim=1)

        deprojected_points = torch.matmul(homo_2d_points, Pinv.t()) * z0

        return deprojected_points


    def undistort_points(distorted_points, f_x, f_y, cx, cy, dist_coeff):
        # center and scale the grid for radius calculation (distance from center of image)
        x = (distorted_points[:, 0] - cx) / f_x
        y = (distorted_points[:, 1] - cy) / f_y

        radius2 = x ** 2 + y ** 2  # distance from the center of image

        radial_distortion = 1 + dist_coeff[0] * radius2 + dist_coeff[1] * radius2 ** 2 + dist_coeff[
            2] * radius2 ** 3  # radial distortion model

        # apply the model
        x = x * radial_distortion
        y = y * radial_distortion

        # reset all the shifting
        x = x * f_x + cx
        y = y * f_y + cy

        return torch.stack((x, y), axis=1)


    def distort_points(projected_points, f_x, f_y, cx, cy, dist_coeff):
        # center and scale the grid for radius calculation (distance from center of image)
        x = (projected_points[:, 0] - cx) / f_x
        y = (projected_points[:, 1] - cy) / f_y

        radius2 = x ** 2 + y ** 2  # distance from the center of image

        radial_distortion = 1 + dist_coeff[0] * radius2 + dist_coeff[1] * radius2 ** 2 + dist_coeff[
            2] * radius2 ** 3  # radial distortion model

        # apply the model
        distorted_x = (x / radial_distortion) * f_x + cx
        distorted_y = (y / radial_distortion) * f_y + cy

        return torch.stack((distorted_x, distorted_y), axis=1)


    # Assuming you have some data and a loss function
    # You would calculate the reprojection error based on the current parameters

    def compute_loss(f_x, f_y, cx, cy, R, distortion, world_points, measured_2d_points, origXY, z0):

        projected_2d_points = project_points(world_points, f_x, f_y, cx, cy, R, z0)

        projection_loss = torch.mean(torch.norm(projected_2d_points - origXY, dim=1))

        #undistorted_measured_points = undistort_points(measured_2d_points, f_x, f_y, cx, cy, distortion)

        #undistortion_loss = torch.mean(torch.norm(undistorted_measured_points - origXY, dim=1))

        #deprojected_3d_points = deproject_points(undistorted_measured_points, f_x, f_y, cx, cy, R.t(), z0)

        # Compute the reprojection error (difference between projected and measured 2D points)
        #deprojection3d_loss = torch.mean(torch.norm(deprojected_3d_points - world_points, dim=1))

        #loss_R_Rinv_I = torch.norm(torch.eye(3, device=device) - R @ R.t(), p='fro')
        #loss_Rinv_R_I = torch.norm(torch.eye(3, device=device) - R.t() @ R, p='fro')
        #loss_Rinv_R_sym = torch.norm(R @ R.t() - R.t() @ R, p='fro')

        #R_ortogonality_loss1 = torch.norm(torch.matmul(R, R.t()) - torch.eye(3, device=device), p='fro')
        #R_ortogonality_loss2 = torch.norm(torch.matmul(R.t(), R) - torch.eye(3, device=device), p='fro')
        #return (projection_loss, deprojection3d_loss, undistortion_loss,
        #        loss_R_Rinv_I, loss_Rinv_R_I, loss_Rinv_R_sym,
        #        R_ortogonality_loss1, R_ortogonality_loss2)
        return projection_loss

    prev = 100000000000000

    preload = False

    if preload:
        final_fx = fx_init = torch.tensor(np.load("./fx.npy"), device=device,
                                          dtype=torch.float64,
                                          requires_grad=False)

        final_fy = fy_init = torch.tensor(np.load("./fy.npy"), device=device,
                                          dtype=torch.float64,
                                          requires_grad=False)

        final_cx = cx_init = torch.tensor(np.load("./cx.npy"), device=device,
                                          dtype=torch.float64,
                                          requires_grad=False)

        final_cy = cy_init = torch.tensor(np.load("./cy.npy"), device=device,
                                          dtype=torch.float64,
                                          requires_grad=False)

        final_alpha = alpha_init = torch.tensor(np.load("./alpha.npy"),
                                                dtype=torch.float64, requires_grad=True,
                                                device=device)

        final_beta = beta_init = torch.tensor(np.load("./beta.npy"),
                                              dtype=torch.float64, requires_grad=True,
                                              device=device)

        final_gamma = gamma_init = torch.tensor(np.load("./gamma.npy"),
                                                dtype=torch.float64, requires_grad=True,
                                                device=device)

        R_init = torch.stack([
            torch.cos(beta_init) * torch.cos(gamma_init),
            torch.sin(alpha_init) * torch.sin(beta_init) * torch.cos(gamma_init) -
            torch.cos(alpha_init) * torch.sin(gamma_init),
            torch.cos(alpha_init) * torch.sin(beta_init) * torch.cos(gamma_init) +
            torch.sin(alpha_init) * torch.sin(gamma_init),
            torch.sin(alpha_init) * torch.cos(beta_init),
            torch.sin(alpha_init) * torch.sin(beta_init) * torch.sin(gamma_init) +
            torch.cos(alpha_init) * torch.cos(gamma_init),
            torch.cos(alpha_init) * torch.sin(beta_init) * torch.sin(gamma_init) -
            torch.sin(alpha_init) * torch.cos(gamma_init),
            -torch.sin(alpha_init), torch.sin(alpha_init) * torch.cos(beta_init),
            torch.cos(alpha_init) * torch.cos(beta_init)
        ]).view(3, 3)

        final_distortion = distortion_init = torch.tensor(np.load("./dist.npy"), device=device,
                                                          dtype=torch.float64, requires_grad=False)

        prev = np.load("./deprojection_loss.npy")

    # Your optimization loop
    train = True
    if train:
        optimizer = optim.Adam([fx_init, fy_init, cx_init,
                                cy_init, alpha_init, beta_init, gamma_init,
                                distortion_init], lr=0.00000001)

        for epoch in range(1000000000):  # Choose your desired number of epochs
            optimizer.zero_grad()

            # Forward pass: compute predicted y by passing x to the model.
            '''
            (projection_loss, deprojection3d_loss, undistortion_loss,
             loss_R_Rinv_I, loss_Rinv_R_I, loss_Rinv_R_sym,
             R_orthogonality_loss1, R_orthogonality_loss2) = compute_loss(fx_init, fy_init, cx_init, cy_init,
                                                                          R_init, distortion_init, world_pts,
                                                                          pts_on_image,
                                                                          torch.tensor(theoretical_valuesXY), z0)
            '''
            projection_loss = compute_loss(fx_init, fy_init, cx_init, cy_init,
                                           R_init, distortion_init, world_pts,
                                           pts_on_image,
                                           torch.tensor(theoretical_valuesXY, device=device), z0)
            # Backward pass: compute gradient of the loss with respect to model parameters

            #total_loss = projection_loss #+ deprojection3d_loss# + \
                         #undistortion_loss + loss_R_Rinv_I + loss_Rinv_R_I + \
                         #loss_Rinv_R_sym + R_orthogonality_loss1 + \
                         #R_orthogonality_loss2
            total_loss = projection_loss #+ deprojection3d_loss + loss_R_Rinv_I + loss_Rinv_R_I + loss_Rinv_R_sym

            total_loss.backward()

            # Calling the step function on an Optimizer makes an update to its parameters
            optimizer.step()

            # Print loss for monitoring
            '''
            if epoch % 1000 == 0:
                print(f"Epoch {epoch}, Projection[px]: {projection_loss.item()}, "
                      f"Deprojection 3D[m]: {deprojection3d_loss.item()},"
                      f"undistortion Loss[p]: {undistortion_loss.item()}, "
                      f"R_Rinv_I [matrix norm2]: {loss_R_Rinv_I.item()},"
                      f"Rinv_R_I [matrix norm2]: {loss_Rinv_R_I.item()}, "
                      f"R Ortogonality 1 [matrix norm2]: {R_orthogonality_loss1.item()},"
                      f"R Ortogonality 2: {R_orthogonality_loss2.item()}")

            if deprojection3d_loss.item() + 0.001 < prev or epoch % 1000000 == 0:
                prev = deprojection3d_loss.item()
                final_fx = fx_init.cpu().detach().numpy()
                final_fy = fy_init.cpu().detach().numpy()
                final_cx = cx_init.cpu().detach().numpy()
                final_cy = cy_init.cpu().detach().numpy()
                final_distortion = distortion_init.cpu().detach().numpy()

                final_alpha = alpha_init.cpu().detach().numpy()
                final_beta = beta_init.cpu().detach().numpy()
                final_gamma = gamma_init.cpu().detach().numpy()

                np.save("./fx", final_fx)
                np.save("./fy", final_fy)
                np.save("./cx", final_cx)
                np.save("./cy", final_cy)

                np.save("./alpha", final_alpha)
                np.save("./beta", final_beta)
                np.save("./gamma", final_gamma)

                np.save("./dist", final_distortion)

                np.save("./deprojection_loss", prev)
        '''

Maybe this is fixed in a more recent version of PyTorch, or maybe it is an incorrect use of PyTorch (I suspect the gradient graph could explain this behavior), but I would be grateful for another look at this.

Thank you in advance.

Here are the text inputs for reproducibility:

measured_valuesX.txt:

1.790000000000000000e+02 4.466677419354838889e+02 7.157034313725490620e+02 9.826348314606741496e+02 1.252000000000000000e+03 1.519000000000000000e+03 1.786000000000000000e+03 2.054448548812664740e+03 2.320000000000000000e+03 2.588000000000000000e+03 2.856000000000000000e+03 3.122394557823129162e+03 3.391558730158730214e+03 3.660000000000000000e+03 3.928000000000000000e+03
1.783841059602648897e+02 4.460000000000000000e+02 7.150000000000000000e+02 9.820000000000000000e+02 1.250710227272727252e+03 1.518000000000000000e+03 1.784570247933884275e+03 2.053000000000000000e+03 2.318000000000000000e+03 2.585764018691588717e+03 2.853686390532544465e+03 3.120000000000000000e+03 3.389000000000000000e+03 3.657000000000000000e+03 3.925000000000000000e+03
1.790000000000000000e+02 4.470000000000000000e+02 7.154166666666666288e+02 9.820000000000000000e+02 1.250766055045871553e+03 1.518000000000000000e+03 1.784439353099730397e+03 2.052706293706293764e+03 2.318000000000000000e+03 2.585000000000000000e+03 2.853000000000000000e+03 3.119000000000000000e+03 3.388000000000000000e+03 3.656000000000000000e+03 3.924000000000000000e+03
1.773841269841269934e+02 4.450000000000000000e+02 7.135790960451977298e+02 9.800000000000000000e+02 1.248698630136986367e+03 1.516000000000000000e+03 1.782000000000000000e+03 2.050234375000000000e+03 2.315330917874396164e+03 2.582648725212464342e+03 2.850703125000000000e+03 3.116541666666666970e+03 3.385000000000000000e+03 3.653000000000000000e+03 3.921000000000000000e+03
1.770000000000000000e+02 4.445899999999999750e+02 7.130000000000000000e+02 9.800000000000000000e+02 1.248260047281323978e+03 1.515326145552560547e+03 1.782000000000000000e+03 2.049684507042253699e+03 2.315000000000000000e+03 2.582000000000000000e+03 2.850000000000000000e+03 3.116000000000000000e+03 3.384496598639455442e+03 3.652468879668049794e+03 3.920000000000000000e+03
1.745999999999999943e+02 4.420000000000000000e+02 7.106624365482233543e+02 9.770000000000000000e+02 1.245719047619047615e+03 1.512592391304347757e+03 1.779000000000000000e+03 2.047000000000000000e+03 2.312000000000000000e+03 2.579000000000000000e+03 2.847000000000000000e+03 3.113000000000000000e+03 3.382000000000000000e+03 3.649608695652173992e+03 3.917000000000000000e+03
1.730000000000000000e+02 4.400000000000000000e+02 7.087180156657963153e+02 9.750000000000000000e+02 1.243733333333333348e+03 1.510679069767441888e+03 1.777000000000000000e+03 2.045000000000000000e+03 2.310313664596273156e+03 2.577334337349397629e+03 2.845000000000000000e+03 3.111000000000000000e+03 3.380000000000000000e+03 3.648000000000000000e+03 3.915000000000000000e+03
1.680000000000000000e+02 4.356163934426229503e+02 7.046632124352331630e+02 9.710000000000000000e+02 1.239651090342679026e+03 1.507000000000000000e+03 1.773000000000000000e+03 2.041247113163972244e+03 2.306695312500000000e+03 2.573632398753893995e+03 2.841580645161290249e+03 3.108000000000000000e+03 3.376452914798206166e+03 3.644468253968253975e+03 3.912000000000000000e+03
1.670000000000000000e+02 4.343503401360543990e+02 7.033076923076922640e+02 9.700000000000000000e+02 1.238649350649350708e+03 1.506000000000000000e+03 1.772313953488372135e+03 2.040523364485981347e+03 2.306000000000000000e+03 2.573000000000000000e+03 2.841000000000000000e+03 3.107000000000000000e+03 3.376000000000000000e+03 3.644000000000000000e+03 3.912000000000000000e+03
1.640000000000000000e+02 4.313355704697986539e+02 7.006685236768802270e+02 9.673115264797507962e+02 1.236000000000000000e+03 1.503331592689295121e+03 1.770000000000000000e+03 2.038405144694533874e+03 2.304000000000000000e+03 2.571000000000000000e+03 2.839498349834983401e+03 3.105607142857143117e+03 3.375000000000000000e+03 3.643000000000000000e+03 3.910509708737864003e+03
1.623787878787878753e+02 4.300000000000000000e+02 6.993916083916084290e+02 9.660000000000000000e+02 1.235000000000000000e+03 1.503000000000000000e+03 1.769384615384615472e+03 2.038000000000000000e+03 2.304000000000000000e+03 2.571000000000000000e+03 2.840000000000000000e+03 3.106000000000000000e+03 3.375000000000000000e+03 3.644000000000000000e+03 3.911000000000000000e+03

measured_valuesY.txt:

1.576443661971831034e+02 1.590000000000000000e+02 1.606911764705882320e+02 1.620000000000000000e+02 1.635333333333333314e+02 1.650000000000000000e+02 1.664658227848101149e+02 1.680000000000000000e+02 1.695231607629427799e+02 1.710000000000000000e+02 1.726549707602339083e+02 1.740000000000000000e+02 1.760000000000000000e+02 1.775714285714285836e+02 1.790000000000000000e+02
4.210000000000000000e+02 4.224172185430463742e+02 4.240000000000000000e+02 4.253750000000000000e+02 4.270000000000000000e+02 4.283445595854922203e+02 4.300000000000000000e+02 4.313507246376811395e+02 4.327063711911357586e+02 4.342850467289719631e+02 4.360000000000000000e+02 4.374790419161676596e+02 4.390000000000000000e+02 4.405978647686832801e+02 4.420000000000000000e+02
6.863542319749216176e+02 6.876644295302013461e+02 6.890000000000000000e+02 6.906535211267605519e+02 6.922591743119265857e+02 6.936194225721784505e+02 6.950000000000000000e+02 6.966993006993006929e+02 6.980000000000000000e+02 6.995261538461538748e+02 7.010000000000000000e+02 7.025934426229508745e+02 7.040000000000000000e+02 7.060000000000000000e+02 7.070000000000000000e+02
9.530000000000000000e+02 9.544969512195121979e+02 9.560000000000000000e+02 9.574369747899160075e+02 9.590000000000000000e+02 9.603728813559322361e+02 9.620000000000000000e+02 9.632477678571428896e+02 9.647560386473429617e+02 9.660000000000000000e+02 9.677161458333333712e+02 9.690000000000000000e+02 9.705416666666666288e+02 9.720000000000000000e+02 9.735719844357976172e+02
1.216000000000000000e+03 1.217000000000000000e+03 1.218658227848101205e+03 1.220000000000000000e+03 1.221702127659574444e+03 1.223000000000000000e+03 1.224409090909090992e+03 1.226000000000000000e+03 1.227341107871720169e+03 1.228606232294617485e+03 1.230000000000000000e+03 1.231510204081632764e+03 1.233000000000000000e+03 1.234000000000000000e+03 1.236000000000000000e+03
1.487000000000000000e+03 1.488000000000000000e+03 1.489723350253807212e+03 1.491000000000000000e+03 1.492678571428571331e+03 1.494000000000000000e+03 1.495422163588390504e+03 1.496706371191135759e+03 1.498000000000000000e+03 1.499495522388059726e+03 1.501000000000000000e+03 1.502000000000000000e+03 1.503535849056603865e+03 1.505000000000000000e+03 1.506381481481481387e+03
1.747468619246861863e+03 1.749000000000000000e+03 1.750323759791122711e+03 1.752000000000000000e+03 1.753271428571428487e+03 1.754760465116279192e+03 1.756000000000000000e+03 1.757462162162162258e+03 1.759000000000000000e+03 1.760000000000000000e+03 1.761489028213166193e+03 1.763000000000000000e+03 1.764000000000000000e+03 1.765388278388278422e+03 1.766588477366255120e+03
2.013393700787401485e+03 2.015000000000000000e+03 2.016259067357513004e+03 2.017681948424068651e+03 2.019000000000000000e+03 2.020553672316384109e+03 2.022000000000000000e+03 2.023249422632794449e+03 2.024674479166666515e+03 2.026000000000000000e+03 2.027000000000000000e+03 2.028405622489959796e+03 2.030000000000000000e+03 2.031000000000000000e+03 2.032000000000000000e+03
2.276396825396825534e+03 2.278000000000000000e+03 2.279269230769230489e+03 2.280697329376854668e+03 2.282000000000000000e+03 2.283547550432276694e+03 2.285000000000000000e+03 2.286000000000000000e+03 2.287502808988764173e+03 2.288664615384615445e+03 2.290000000000000000e+03 2.291000000000000000e+03 2.292374531835205744e+03 2.293500000000000000e+03 2.295000000000000000e+03
2.544404109589041127e+03 2.546000000000000000e+03 2.547278551532033362e+03 2.549000000000000000e+03 2.550347953216374208e+03 2.551731070496083703e+03 2.553000000000000000e+03 2.554000000000000000e+03 2.555469026548672446e+03 2.556616393442623121e+03 2.558000000000000000e+03 2.559000000000000000e+03 2.560000000000000000e+03 2.561000000000000000e+03 2.562000000000000000e+03
2.805000000000000000e+03 2.807000000000000000e+03 2.808000000000000000e+03 2.810000000000000000e+03 2.811000000000000000e+03 2.812496794871794918e+03 2.814000000000000000e+03 2.815000000000000000e+03 2.816335463258786149e+03 2.817449180327868817e+03 2.818588461538461615e+03 2.820000000000000000e+03 2.821000000000000000e+03 2.822000000000000000e+03 2.823000000000000000e+03

theoretical_valuesX.txt:

1.777584130093946442e+02 4.451238144521489630e+02 7.123517415740193428e+02 9.794651787522259383e+02 1.246487071062176256e+03 1.513440332172158378e+03 1.780347852215128569e+03 2.047232505643341028e+03 2.314117159071553488e+03 2.581024679114523678e+03 2.847977940224505801e+03 3.114999832534455891e+03 3.382113269712662714e+03 3.649341196834533093e+03 3.916706598277287412e+03
1.777584130093946442e+02 4.451238144521489630e+02 7.123517415740193428e+02 9.794651787522259383e+02 1.246487071062176256e+03 1.513440332172158378e+03 1.780347852215128569e+03 2.047232505643341028e+03 2.314117159071553488e+03 2.581024679114523678e+03 2.847977940224505801e+03 3.114999832534455891e+03 3.382113269712662714e+03 3.649341196834533093e+03 3.916706598277287412e+03
1.777584130093946442e+02 4.451238144521489630e+02 7.123517415740193428e+02 9.794651787522259383e+02 1.246487071062176256e+03 1.513440332172158378e+03 1.780347852215128569e+03 2.047232505643341028e+03 2.314117159071553488e+03 2.581024679114523678e+03 2.847977940224505801e+03 3.114999832534455891e+03 3.382113269712662714e+03 3.649341196834533093e+03 3.916706598277287412e+03
1.777584130093946442e+02 4.451238144521489630e+02 7.123517415740193428e+02 9.794651787522259383e+02 1.246487071062176256e+03 1.513440332172158378e+03 1.780347852215128569e+03 2.047232505643341028e+03 2.314117159071553488e+03 2.581024679114523678e+03 2.847977940224505801e+03 3.114999832534455891e+03 3.382113269712662714e+03 3.649341196834533093e+03 3.916706598277287412e+03
1.777584130093946442e+02 4.451238144521489630e+02 7.123517415740193428e+02 9.794651787522259383e+02 1.246487071062176256e+03 1.513440332172158378e+03 1.780347852215128569e+03 2.047232505643341028e+03 2.314117159071553488e+03 2.581024679114523678e+03 2.847977940224505801e+03 3.114999832534455891e+03 3.382113269712662714e+03 3.649341196834533093e+03 3.916706598277287412e+03
1.777584130093946442e+02 4.451238144521489630e+02 7.123517415740193428e+02 9.794651787522259383e+02 1.246487071062176256e+03 1.513440332172158378e+03 1.780347852215128569e+03 2.047232505643341028e+03 2.314117159071553488e+03 2.581024679114523678e+03 2.847977940224505801e+03 3.114999832534455891e+03 3.382113269712662714e+03 3.649341196834533093e+03 3.916706598277287412e+03
1.777584130093946442e+02 4.451238144521489630e+02 7.123517415740193428e+02 9.794651787522259383e+02 1.246487071062176256e+03 1.513440332172158378e+03 1.780347852215128569e+03 2.047232505643341028e+03 2.314117159071553488e+03 2.581024679114523678e+03 2.847977940224505801e+03 3.114999832534455891e+03 3.382113269712662714e+03 3.649341196834533093e+03 3.916706598277287412e+03
1.777584130093946442e+02 4.451238144521489630e+02 7.123517415740193428e+02 9.794651787522259383e+02 1.246487071062176256e+03 1.513440332172158378e+03 1.780347852215128569e+03 2.047232505643341028e+03 2.314117159071553488e+03 2.581024679114523678e+03 2.847977940224505801e+03 3.114999832534455891e+03 3.382113269712662714e+03 3.649341196834533093e+03 3.916706598277287412e+03
1.777584130093946442e+02 4.451238144521489630e+02 7.123517415740193428e+02 9.794651787522259383e+02 1.246487071062176256e+03 1.513440332172158378e+03 1.780347852215128569e+03 2.047232505643341028e+03 2.314117159071553488e+03 2.581024679114523678e+03 2.847977940224505801e+03 3.114999832534455891e+03 3.382113269712662714e+03 3.649341196834533093e+03 3.916706598277287412e+03
1.777584130093946442e+02 4.451238144521489630e+02 7.123517415740193428e+02 9.794651787522259383e+02 1.246487071062176256e+03 1.513440332172158378e+03 1.780347852215128569e+03 2.047232505643341028e+03 2.314117159071553488e+03 2.581024679114523678e+03 2.847977940224505801e+03 3.114999832534455891e+03 3.382113269712662714e+03 3.649341196834533093e+03 3.916706598277287412e+03
1.777584130093946442e+02 4.451238144521489630e+02 7.123517415740193428e+02 9.794651787522259383e+02 1.246487071062176256e+03 1.513440332172158378e+03 1.780347852215128569e+03 2.047232505643341028e+03 2.314117159071553488e+03 2.581024679114523678e+03 2.847977940224505801e+03 3.114999832534455891e+03 3.382113269712662714e+03 3.649341196834533093e+03 3.916706598277287412e+03

theoretical_valuesY.txt:

1.638777009419650312e+02 1.638777009419650312e+02 1.638777009419650312e+02 1.638777009419650312e+02 1.638777009419650312e+02 1.638777009419650312e+02 1.638777009419650312e+02 1.638777009419650312e+02 1.638777009419650312e+02 1.638777009419650312e+02 1.638777009419650312e+02 1.638777009419650312e+02 1.638777009419650312e+02 1.638777009419650312e+02 1.638777009419650312e+02
4.309911381201716267e+02 4.309911381201716267e+02 4.309911381201716267e+02 4.309911381201716267e+02 4.309911381201716267e+02 4.309911381201716267e+02 4.309911381201716267e+02 4.309911381201716267e+02 4.309911381201716267e+02 4.309911381201716267e+02 4.309911381201716267e+02 4.309911381201716267e+02 4.309911381201716267e+02 4.309911381201716267e+02 4.309911381201716267e+02
6.980130304301220576e+02 6.980130304301220576e+02 6.980130304301220576e+02 6.980130304301220576e+02 6.980130304301220576e+02 6.980130304301220576e+02 6.980130304301220576e+02 6.980130304301220576e+02 6.980130304301220576e+02 6.980130304301220576e+02 6.980130304301220576e+02 6.980130304301220576e+02 6.980130304301220576e+02 6.980130304301220576e+02 6.980130304301220576e+02
9.649662915401041801e+02 9.649662915401041801e+02 9.649662915401041801e+02 9.649662915401041801e+02 9.649662915401041801e+02 9.649662915401041801e+02 9.649662915401041801e+02 9.649662915401041801e+02 9.649662915401041801e+02 9.649662915401041801e+02 9.649662915401041801e+02 9.649662915401041801e+02 9.649662915401041801e+02 9.649662915401041801e+02 9.649662915401041801e+02
1.231873811583074257e+03 1.231873811583074257e+03 1.231873811583074257e+03 1.231873811583074257e+03 1.231873811583074257e+03 1.231873811583074257e+03 1.231873811583074257e+03 1.231873811583074257e+03 1.231873811583074257e+03 1.231873811583074257e+03 1.231873811583074257e+03 1.231873811583074257e+03 1.231873811583074257e+03 1.231873811583074257e+03 1.231873811583074257e+03
1.498758465011286717e+03 1.498758465011286717e+03 1.498758465011286717e+03 1.498758465011286717e+03 1.498758465011286717e+03 1.498758465011286717e+03 1.498758465011286717e+03 1.498758465011286717e+03 1.498758465011286717e+03 1.498758465011286717e+03 1.498758465011286717e+03 1.498758465011286717e+03 1.498758465011286717e+03 1.498758465011286717e+03 1.498758465011286717e+03
1.765643118439499176e+03 1.765643118439499176e+03 1.765643118439499176e+03 1.765643118439499176e+03 1.765643118439499176e+03 1.765643118439499176e+03 1.765643118439499176e+03 1.765643118439499176e+03 1.765643118439499176e+03 1.765643118439499176e+03 1.765643118439499176e+03 1.765643118439499176e+03 1.765643118439499176e+03 1.765643118439499176e+03 1.765643118439499176e+03
2.032550638482469367e+03 2.032550638482469367e+03 2.032550638482469367e+03 2.032550638482469367e+03 2.032550638482469367e+03 2.032550638482469367e+03 2.032550638482469367e+03 2.032550638482469367e+03 2.032550638482469367e+03 2.032550638482469367e+03 2.032550638482469367e+03 2.032550638482469367e+03 2.032550638482469367e+03 2.032550638482469367e+03 2.032550638482469367e+03
2.299503899592451489e+03 2.299503899592451489e+03 2.299503899592451489e+03 2.299503899592451489e+03 2.299503899592451489e+03 2.299503899592451489e+03 2.299503899592451489e+03 2.299503899592451489e+03 2.299503899592451489e+03 2.299503899592451489e+03 2.299503899592451489e+03 2.299503899592451489e+03 2.299503899592451489e+03 2.299503899592451489e+03 2.299503899592451489e+03
2.566525791902401579e+03 2.566525791902401579e+03 2.566525791902401579e+03 2.566525791902401579e+03 2.566525791902401579e+03 2.566525791902401579e+03 2.566525791902401579e+03 2.566525791902401579e+03 2.566525791902401579e+03 2.566525791902401579e+03 2.566525791902401579e+03 2.566525791902401579e+03 2.566525791902401579e+03 2.566525791902401579e+03 2.566525791902401579e+03
2.833639229080608402e+03 2.833639229080608402e+03 2.833639229080608402e+03 2.833639229080608402e+03 2.833639229080608402e+03 2.833639229080608402e+03 2.833639229080608402e+03 2.833639229080608402e+03 2.833639229080608402e+03 2.833639229080608402e+03 2.833639229080608402e+03 2.833639229080608402e+03 2.833639229080608402e+03 2.833639229080608402e+03 2.833639229080608402e+03

Hi Ilya!

It is true that you only have one line of code that calls .backward(), but
that line of code exists in a loop that runs multiple times. When the second
iteration of the loop runs, you call .backward() a second time, causing
the error.

beta_init is a trainable parameter (as it has requires_grad = True).

Outside (prior to) your training loop, you compute R_init in terms of
beta_init. This builds a piece of the computation graph that connects
the derived R_init to the computation-graph leaf (and trainable
parameter) beta_init. This piece of the computation graph is built only
once.

total_loss (which is projection_loss) depends on R_init. When
you call total_loss.backward(), you backpropagate through the entire
computation graph, freeing the graph along the way, including that part of
the graph that connects R_init to beta_init.

When you execute the second iteration of your loop, you rebuild the part
of the computation graph that gets built by compute_loss(). But you
don’t rebuild that part of the graph that connects R_init to beta_init.
When you call total_loss.backward() in your second iteration, it
fails with the reported error when it tries to backpropagate through the
R_init → beta_init part of the graph that was freed in the first iteration
(and not rebuilt).
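
Here is a minimal sketch of the same failure mode, using made-up names
rather than your actual variables:

import torch

beta = torch.tensor([0.], requires_grad=True)   # leaf / trainable parameter
R = torch.cos(beta)                              # derived from beta, built only once

for step in range(2):
    loss = (R * 2.0).sum()   # this piece of the graph is rebuilt every iteration
    loss.backward()          # frees the whole graph, including the beta -> R piece
    # the second iteration raises "Trying to backward through the graph a second time"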

I haven’t looked at the details of your code, but it might be possible to
fix your error by moving the computation of R_init inside of your training
loop so that that part of the graph also gets rebuilt.
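
As a rough sketch of that change (keeping the rest of your script as posted;
num_epochs and origXY_t here are placeholder names, not variables from your code):

    origXY_t = torch.tensor(theoretical_valuesXY, device=device)

    optimizer = optim.Adam([fx_init, fy_init, cx_init,
                            cy_init, alpha_init, beta_init, gamma_init,
                            distortion_init], lr=0.00000001)

    for epoch in range(num_epochs):
        optimizer.zero_grad()

        # rebuild R from the current beta_init so the beta_init -> R piece of the
        # graph is fresh on every iteration (and can safely be freed by backward())
        zero = torch.zeros(1, dtype=torch.float64, device=device)
        one = torch.ones(1, dtype=torch.float64, device=device)
        R = torch.stack([one, zero, zero,
                         zero, torch.cos(beta_init), -torch.sin(beta_init),
                         zero, torch.sin(beta_init), torch.cos(beta_init)]).view(3, 3)

        projection_loss = compute_loss(fx_init, fy_init, cx_init, cy_init,
                                       R, distortion_init, world_pts,
                                       pts_on_image, origXY_t, z0)
        projection_loss.backward()
        optimizer.step()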

Best.

K. Frank

@KFrank, moving it inside the training loop was exactly what I was considering; it looks like it may be a problem of only beta_init being a leaf.

That indeed solved it.