Try to understand how to use optim.LBFGS

Hi all,

I want to use ‘optimizer = optim.LBFGS(model.parameters(), lr=1e-4)’ instead of ‘optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)’,
but I don’t know how to write the ‘def closure():’ function it requires.
Can someone please explain how to modify the following code to use optim.LBFGS?

here is the code:

import copy

import numpy as np
import torch
import torch.nn as nn

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

def train_model(model, trainDataset, valDataset, number_epochs):
  optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)

  criterion = nn.MSELoss().to(device)
  hist = dict(train=[], val=[])

  best_loss = 25.0
  best_m_wts = copy.deepcopy(model.state_dict())

  for epoch in range(1, number_epochs + 1):
    model = model.train()

    train_losses = []
    for sequence_true in trainDataset:
      optimizer.zero_grad()

      sequence_true = sequence_true.to(device)
      sequence_pred = model(sequence_true)

      loss = criterion(sequence_pred, sequence_true)

      loss.backward()
      optimizer.step()

      train_losses.append(loss.item())

    val_losses = []
    model = model.eval()
    with torch.no_grad():
      for sequence_true in valDataset:

        sequence_true = sequence_true.to(device)
        sequence_pred = model(sequence_true)

        loss = criterion(sequence_pred, sequence_true)
        val_losses.append(loss.item())

    train_loss = np.mean(train_losses)
    val_loss = np.mean(val_losses)

    hist['train'].append(train_loss)
    hist['val'].append(val_loss)

    if val_loss < best_loss:
      best_loss = val_loss
      best_m_wts = copy.deepcopy(model.state_dict())

    print(f'Epoch {epoch}: train loss {train_loss} val loss {val_loss}')

  model.load_state_dict(best_m_wts)
  return model.eval(), hist

Thanks a lot for your help

Hi Py!

LBFGS is quite different in character from basically all of pytorch’s
other optimizers (which are all more-or-less purely gradient-descent
based), so you can’t use LBFGS as a plug-and-play replacement for,
say, Adam.

Rather than starting by rewriting some existing code to use LBFGS,
I would recommend getting LBFGS working with a toy computation. The
following links to an example that I haven’t tried running, but appears to
contain the necessary ingredients: pytorch-L-BFGS-example.
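For reference, a minimal toy computation along those lines might look like the following. This is my own sketch, not the linked example; the data, model, and hyperparameters are made up purely for illustration:

```python
import torch

# Toy problem: fit y = 2x + 1 with a single linear layer using LBFGS.
torch.manual_seed(0)
x = torch.linspace(-1, 1, 50).unsqueeze(1)
y = 2 * x + 1

model = torch.nn.Linear(1, 1)
criterion = torch.nn.MSELoss()
optimizer = torch.optim.LBFGS(model.parameters(), lr=0.1, max_iter=20)

def closure():
    # Recompute loss and gradients for the current parameter values.
    optimizer.zero_grad()
    loss = criterion(model(x), y)
    loss.backward()
    return loss

for _ in range(10):
    # step() will call closure() internally, possibly many times.
    loss = optimizer.step(closure)

print(criterion(model(x), y).item())  # loss should be close to zero
```

Note that, unlike with Adam, the forward / backward pass lives inside closure() and step() receives the closure as an argument.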

If you still have questions, please post your attempt to use LBFGS on
some toy problem (providing a fully-self-contained, runnable script)
with specific questions about that attempt.

Good luck!

K. Frank

Hi @KFrank,

Thanks a lot for your help.
I’ll try to understand the example to see how it works.

By the way, I’ve tried rewriting the code to use torch.optim.LBFGS by following the example. Could you please tell me whether what I’ve written is correct?

def train_model(model, trainDataset, valDataset, number_epochs):
  optimizer = torch.optim.LBFGS(model.parameters(), lr=1e-4)
  # change the criterion
  criterion = nn.MSELoss().to(device)
  hist = dict(train=[], val=[])

  best_loss = 25.0
  best_m_wts = copy.deepcopy(model.state_dict())

  # Add closure function because of torch.optim.LBFGS
  def closure():
    optimizer.zero_grad()
    train_losses = []

    model.train()
    for sequence_true in trainDataset:
      sequence_true = sequence_true.to(device)
      sequence_pred = model(sequence_true)
      loss = criterion(sequence_pred, sequence_true)
      loss.backward()
      train_losses.append(loss.item())

    return np.mean(train_losses)

  for epoch in range(1, number_epochs + 1):
    model = model.train()
    train_loss = optimizer.step(closure)

    val_losses = []
    model = model.eval()
    with torch.no_grad():
      for sequence_true in valDataset:

        sequence_true = sequence_true.to(device)
        sequence_pred = model(sequence_true)

        loss = criterion(sequence_pred, sequence_true)
        val_losses.append(loss.item())

    val_loss = np.mean(val_losses)

    hist['train'].append(train_loss)
    hist['val'].append(val_loss)

    if val_loss < best_loss:
      best_loss = val_loss
      best_m_wts = copy.deepcopy(model.state_dict())

    print(f'Epoch {epoch}: train loss {train_loss} val loss {val_loss}')

  model.load_state_dict(best_m_wts)
  return model.eval(), hist

On the other hand, I have a question about torch.optim.LBFGS. Do you have a link that explains the theory, so that I can understand this optimiser a little better and what differentiates it from other optimisers?

Thanks a lot for your help

Hi Py!

Please post a super-simple, fully-self-contained, runnable script, together
with its output, and ask a specific question if there is something you don’t
understand or if something looks wrong.

First, take a look at the minFunc link in pytorch’s LBFGS documentation.
You can also find a good overview in Wikipedia’s LBFGS entry.

Most pytorch optimizers use some form of gradient descent. LBFGS also
uses gradients to choose a so-called search direction, but then searches
along that direction using the value of the loss function itself.
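Pytorch’s LBFGS exposes this through the line_search_fn constructor argument; passing "strong_wolfe" enables an actual line search, while the default just takes fixed lr-sized steps along the search direction. A minimal sketch (the quadratic objective here is made up for illustration):

```python
import torch

# Minimize ||p - c||^2, a convex bowl, with a strong-Wolfe line search.
c = torch.tensor([1.0, -2.0, 3.0])
p = torch.zeros(3, requires_grad=True)

# line_search_fn="strong_wolfe" makes step() search along the
# quasi-Newton direction using values of the loss itself.
optimizer = torch.optim.LBFGS([p], lr=1.0, line_search_fn="strong_wolfe")

def closure():
    optimizer.zero_grad()
    loss = ((p - c) ** 2).sum()
    loss.backward()
    return loss

optimizer.step(closure)
print(p)  # a single step() drives p essentially onto c
```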

Best.

K. Frank

Hi @KFrank,

Thanks a lot for your answers and your help,
I posted both versions of the code because I only wanted to understand how this works:

def closure():

Why and when should I use this function? What do I need to put in it? And why does this optimizer use it?
I’ll now try to study a simpler example and try to understand this optimizer better and how to use closure().

Once again thank you for your help

Hi Py!

The short story: closure() basically packages the forward / backward
pass of the training loop that you would use for a “conventional” pytorch
optimizer such as SGD. But why?

You are optimizing (specifically minimizing) your loss as a function of
the parameters of your model.

With most of pytorch’s optimizers, it works like this:

Your code (by which I mean code that is not part of the optimizer) performs
a forward and backward pass. This populates the .grad properties of the
model’s parameters with the gradient of the loss (for the current value of
the model’s parameters).

You then call optimizer.step(). All this does is to change the model’s
parameters to some new values based on the gradient of the loss.
(Not that it matters that much, but the optimizer doesn’t care about
the value of loss, just its gradient.)

However, a single step of a LBFGS optimizer looks at the value of loss
(and its gradient) at multiple places in parameter space, and uses those
values to decide where else to jump around in parameter space to look
at more values of loss.

LBFGS takes responsibility for causing loss and its gradient to be
recomputed at each point in parameter space that it investigates, but
it doesn’t actually know how to perform this recomputation. So you
tell it how to perform the recomputation by passing a function – named
closure(), by convention – into the optimizer.step(closure)
call. It is closure()'s responsibility to return the value of loss and
populate the model’s .grad properties with the gradient (evaluated
for whatever values the model’s parameters have when closure()
is called). To be clear, LBFGS tells closure() what values of the model’s
parameters to use when closure() recomputes loss and .grad by
updating the model’s parameters before calling closure().

closure() normally does this by performing a single forward and backward
pass based on the current state of the model, returning the loss that was
computed by the forward pass and letting the backward pass populate
the .grads. (closure() doesn’t have to perform a forward / backward
pass – it can do whatever it wants as long as it returns some loss value
and somehow populates the .grads.)

Note, closure() will typically be called many times each time you call
optimizer.step().
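You can see this by counting the calls yourself. A small sketch (the objective is made up for illustration):

```python
import torch

# Count how many times LBFGS invokes closure() during a single step().
torch.manual_seed(0)
target = torch.randn(5)
p = torch.randn(5, requires_grad=True)
optimizer = torch.optim.LBFGS([p], max_iter=20)

n_calls = 0

def closure():
    global n_calls
    n_calls += 1
    optimizer.zero_grad()
    loss = ((p - target) ** 2).sum()
    loss.backward()
    return loss

optimizer.step(closure)
print(n_calls)  # more than one closure() call for this one step()
```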

So, unlike all the other pytorch optimizers (which are all basically
single-step gradient-descent optimizers), LBFGS takes much more
responsibility about how to search through parameter space while
optimizing loss. (For what it’s worth, it’s a quasi-Newton optimization
algorithm with line search.)

Best.

K. Frank


Hi @KFrank,

Thanks a lot for your very clear answer, now I have a clear understanding.