Unclear purpose of max_iter kwarg in the LBFGS optimizer

I am interested in understanding the precise purpose of the max_iter kwarg in the LBFGS optimizer. As far as I can tell from inspecting the code, it has no effect on the number of gradients that are used to construct the limited memory Hessian, and similarly it has no direct effect on the number of linesearch calls or any other internal workings of the LBFGS algorithm. It seems that the only function of max_iter is to group a certain number of LBFGS steps together within one outwardly exposed step of the Optimizer class.

If there is something more to max_iter I would be very grateful if someone could explain it.

If this is truly the only function of max_iter, then why is the default value set to 20? This seems like a strange choice (and a strange piece of functionality in general), since none of the first-order optimizers (Adam, SGD, etc.) have anything similar, which effectively fixes their max_iter at 1. Setting max_iter > 1 in LBFGS just makes the algorithm appear extremely slow per optimizer step (compared to the first-order methods) while showing remarkably good convergence per optimizer step. Of course, compared on equal footing (i.e. max_iter = 1), LBFGS with no line search is only a little slower than Adam and converges only a little faster per step. I would think it would be easier for users to understand the true differences in performance between the two algorithms if the default value of max_iter were 1 instead of 20.
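For what it's worth, here is a minimal sketch of the kind of per-step comparison I mean (the quadratic objective, learning rates, and step count are just illustrative choices, not a benchmark):

import torch

# toy objective; any smooth function would do here
def loss_fn(x):
    return (x ** 2).sum()

# same starting point for both optimizers
x_adam = torch.randn(10, requires_grad=True)
x_lbfgs = x_adam.detach().clone().requires_grad_(True)

adam = torch.optim.Adam([x_adam], lr=0.1)
# max_iter = 1 and no line_search_fn, so each step() is a single L-BFGS iteration
lbfgs = torch.optim.LBFGS([x_lbfgs], lr=0.1, max_iter=1)

def lbfgs_closure():
    lbfgs.zero_grad()
    loss = loss_fn(x_lbfgs)
    loss.backward()
    return loss

for step in range(50):
    adam.zero_grad()
    loss_fn(x_adam).backward()
    adam.step()
    lbfgs.step(lbfgs_closure)
    # one Adam step vs. one L-BFGS iteration, same number of optimizer steps
    print(step, loss_fn(x_adam).item(), loss_fn(x_lbfgs).item())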

Thanks!


Hello,

I think I have figured it out:
I set up the optimizer with history_size = 3 and max_iter = 1. After each optimizer.step() call you can print the optimizer state with print(optimizer.state[optimizer._params[0]]), and the number of stored old directions taken into account in each iteration with print(len(optimizer.state[optimizer._params[0]]['old_dirs'])) (a minimal sketch of this setup follows the output below). The output of the latter is:
0
1
2
3
3
3
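
A minimal sketch of the setup described above (the quadratic objective and its dimensions are arbitrary choices; any differentiable function would show the same pattern):

import torch

torch.manual_seed(0)
A = torch.randn(10, 10)
A = A @ A.t() / 10 + torch.eye(10)   # random positive definite matrix
x = torch.randn(10, requires_grad=True)

optimizer = torch.optim.LBFGS([x], max_iter=1, history_size=3)

def closure():
    optimizer.zero_grad()
    loss = 0.5 * x @ A @ x
    loss.backward()
    return loss

for _ in range(6):
    optimizer.step(closure)
    state = optimizer.state[optimizer._params[0]]
    # number of stored curvature pairs: grows by one whenever a new pair is
    # accepted and is capped at history_size (3 here)
    print(len(state['old_dirs']))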

So the two-loop recursion

# iteration in L-BFGS loop collapsed to use just one buffer
q = flat_grad.neg()
for i in range(num_old - 1, -1, -1):
    al[i] = old_stps[i].dot(q) * ro[i]
    q.add_(old_dirs[i], alpha=-al[i])

# multiply by initial Hessian
# r/d is the final direction
d = r = torch.mul(q, H_diag)
for i in range(num_old):
    be_i = old_dirs[i].dot(r) * ro[i]
    r.add_(old_stps[i], alpha=al[i] - be_i)

is correct (see e.g. Jorge Nocedal, Stephen Wright: Numerical Optimization, Algorithm 7.4). This means that max_iter is really just the maximum number of iterations performed per optimization step, and since the optimizer state carries over between steps, max_iter > 1 has no effect beyond bundling several iterations into one step() call. Similarly, max_eval, which only limits the number of function evaluations within a single step() call, is not particularly useful either. By the way, according to Nocedal and Wright it is usually sufficient to set history_size between 5 and 20, and a line search is needed to keep the (L-)BFGS update stable.
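For reference, here is the same recursion written out as a self-contained function (just a sketch; the variable names follow the PyTorch snippet above, with old_dirs holding the y vectors, old_stps the s vectors, ro[i] = 1/(y_i . s_i), and H_diag the initial Hessian scaling):

import torch

def two_loop_recursion(flat_grad, old_dirs, old_stps, ro, H_diag):
    # standard L-BFGS two-loop recursion (Nocedal & Wright, Algorithm 7.4)
    num_old = len(old_dirs)
    al = [None] * num_old

    # first loop: walk backwards through the stored history
    q = flat_grad.neg()
    for i in range(num_old - 1, -1, -1):
        al[i] = old_stps[i].dot(q) * ro[i]
        q.add_(old_dirs[i], alpha=-al[i])

    # multiply by the initial (diagonal) Hessian approximation
    r = torch.mul(q, H_diag)

    # second loop: walk forwards to finish applying the inverse Hessian
    for i in range(num_old):
        be_i = old_dirs[i].dot(r) * ro[i]
        r.add_(old_stps[i], alpha=al[i] - be_i)

    # r approximates -H^{-1} grad, i.e. the search direction d
    return r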


As you wrote, the existence of the max_iter kwarg is unusual. It turns out that it makes no difference to the internal iteration: the state of the optimizer is preserved across calls to optimizer.step(), so one call with max_iter = N produces the same iterates as N calls with max_iter = 1. Minimal working example to show this:

import torch

# rosenbrock function (a = 1, b = 100)
# https://en.wikipedia.org/wiki/Rosenbrock_function
def f(x): 
    return (1 - x[0])**2 + 100*(x[1] - x[0]**2)**2

# random initial iterate. Every re-run of this code will yield a different
# initial iterate
d = 30
xk = torch.rand(2,1) * d - d/2

# same initial iterate, different optimizers
xk1 = xk.clone()
xk1.requires_grad = True
xk2 = xk.clone()
xk2.requires_grad = True

# LBFGS parameters
max_iter = 1000
gnorm_tol = 1e-5
tolerance_change = 1e-5
# history_size is set to max_iter so that all stored curvature pairs are kept
# (avoids the internal .pop() of old entries)
history_size = max_iter

# max_iter steps, no external loop
optimizer1 = torch.optim.LBFGS([xk1],
                               max_iter = max_iter,
                               history_size = history_size,
                               tolerance_grad = gnorm_tol,
                               tolerance_change = tolerance_change,
                               line_search_fn = "strong_wolfe")

def closure1():
    optimizer1.zero_grad()
    y = f(xk1)
    y.backward()
    return y

# single step, yes external loop
optimizer2 = torch.optim.LBFGS([xk2],
                               max_iter = 1,
                               history_size = history_size,
                               tolerance_grad = gnorm_tol,
                               tolerance_change = tolerance_change,
                               line_search_fn = "strong_wolfe")

def closure2():
    optimizer2.zero_grad()
    y = f(xk2)
    y.backward()
    return y

# comparison
# optimizer1
optimizer1.step(closure1)

"""
For reference, here are some of the available keys for the state of the optimizer:
optimizer1.state[optimizer1._params[0]][<key string>]

Key strings and definitions
"d" : search direction at current iteration
"t" : step size at current iteration
"old_dirs" : differences in successive gradients up to history_size (y vectors)
"old_stps" : differences in successive iterates up to history_size (s vectors)
"ro" : 1 / (y.T @ s) at current iterate
"n_iter" : number of iterations so far
"""

# repeat the same for optimizer2 using the same number of iterations
num_iter = optimizer1.state[optimizer1._params[0]]['n_iter']

# optimizer2
for _ in range(num_iter):
    optimizer2.step(closure2)

# compare recorded differences in successive gradients (y vector) and
# recorded differences in successive steps (s vector)
y1 = optimizer1.state[optimizer1._params[0]]['old_dirs']
y2 = optimizer2.state[optimizer2._params[0]]['old_dirs']
s1 = optimizer1.state[optimizer1._params[0]]['old_stps']
s2 = optimizer2.state[optimizer2._params[0]]['old_stps']

ys_equal = all([torch.all(z1 == z2) for z1,z2 in zip(y1,y2)])
ss_equal = all([torch.all(z1 == z2) for z1,z2 in zip(s1,s2)])

print("y vectors equal? {}\ns vectors equal? {}".format(ys_equal,ss_equal))