Dynamic closure for LBFGS


I am fairly new to Pytorch so please excuse any errors.

I am trying to implement the NOTEARS algorithm from this paper - https://papers.nips.cc/paper/2018/file/e347c51419ffb23ca3fd5050202f9c3d-Paper.pdf

This paper uses the augmented Lagrangian method for solving the optimisation problem. I am using this implementation of LBFGS - GitHub - hjmshi/PyTorch-LBFGS: A PyTorch implementation of L-BFGS.. The key point is that I need to pass arguments to the closure() function. This works fine with a lambda pattern for some but not all of the variables. So from section 4.1 of the paper, I need to pass W, rho and alpha to the closure function as these are updated on the dual ascent. The lambda pattern works for W and rho but when I try to pass alpha as well I get a runtime error:

Trying to backward through the graph a second time, but the saved intermediate results have already been freed. Specify retain_graph=True when calling backward first time.

It seems that the scalar alpha is turned into a tensor after the first iteration of LBFGS.
Anyone got any suggestions as to why that would happen to alpha and not rho?

""" This is a test to see if I can get a pytorch LBFGS
working. Just try simple linear regression. """ 

# Imports
import torch
import math
import gpytorch
import notears
import cdt 
import sys
sys.path.insert(1, '/home/sal/documents/Tests/torch_lbfgs/functions/')
import numpy as np
from matplotlib import pyplot as plt
from LBFGS import FullBatchLBFGS
from cdt.metrics import SHD
from cdt.metrics import SID

# Generate random DAG
num_nodes = 10
num_edges = 10
edge_coefficient_range = [0.5, 2.0]

true_adj_mat, _ = notears.utils.generate_random_dag(num_nodes, num_edges, edge_coefficient_range=edge_coefficient_range)

# Simulate data
n_sample = 100

X = notears.utils.simulate_linear_sem(true_adj_mat, n_sample, 'uniform').astype(np.float32)

# We want to define a least squares problem where we want to learn the identity function
n = num_nodes
m = n_sample
half_recip_m = 1/(2*m)

# Define variables for optimisation 
X = torch.from_numpy(X)

# Initialise W
W = torch.zeros((n, n), requires_grad=True)

# Define loss functions
def dag_loss(W, n):
    return torch.trace(torch.eye(n) + torch.multiply(W, W)) - n

def least_squares_loss(X, W, half_recip_m):
    return half_recip_m*torch.norm(X - torch.matmul(X, W), p="fro")
def loss_fn(X, W, half_recip_m, n, rho, alpha):
    ll_loss = least_squares_loss(X, W, half_recip_m)
    h_loss = dag_loss(W, n)
    loss =  ll_loss + 0.5*rho*torch.pow(h_loss, 2) + alpha*h_loss
    return loss

def closure(**kwargs):
    W = kwargs["W"]
    rho = kwargs["rho"]
    alpha = kwargs["alpha"]
    loss = loss_fn(X, W, half_recip_m, n, rho, alpha)
    return loss

def compute_W_star(X, W, half_recip_m, n, rho, alpha, inner_max_iters):
    loss = closure(W=W, rho=rho, alpha=alpha)
    for i in range(inner_max_iters):

        # Perform step and update curvature
        options = {'closure': lambda: closure(W=W, rho=rho, alpha=alpha),
                   'current_loss': loss, 'max_ls': 10}
        loss, _, lr, _, F_eval, G_eval, _, _ = optimizer.step(options)

    return W 

def dual_ascent(X, W, half_recip_m, n, rho, rho_multiplier,
                alpha, outer_max_iters, inner_max_iters, h_tol):

    for iteration in range(outer_max_iters):

        W_star = compute_W_star(X, W, half_recip_m, n, rho, alpha, inner_max_iters)
        h_star = dag_loss(W_star, n)
        ll_star = least_squares_loss(X, W_star, half_recip_m)

        print('Acyclicity loss: {}'.format(h_star))
        print('Least squares loss: {}'.format(ll_star))
        if h_star < h_tol:
            alpha = alpha + rho*h_star
            rho = rho*rho_multiplier
            W = W_star

    return W_star
# Initialise optimiser
optimizer = FullBatchLBFGS([W])

# Set parameters
lamb = 1e-8
rho = 1.0 
alpha = 0.1
h_tol = 1e-8
inner_max_iters = 20
outer_max_iters = 10
rho_multiplier = 10

W_sol = dual_ascent(X, W, half_recip_m, n, rho, rho_multiplier,
                    alpha, outer_max_iters, inner_max_iters, h_tol)