Accessing the state of Higher's differentiable optimizers

The differentiable optimizers of Higher (facebookresearch/higher on GitHub, a PyTorch library for obtaining higher-order gradients over losses that span entire training loops rather than individual training steps) don't seem to have a state_dict() method like regular optimizers.

Can I still access and save their state somehow? I'd like to load it into a regular optimizer later.

I'm not deeply familiar with higher, but it seems that the DifferentiableOptimizer base class accepts an already created optimizer object, so I would assume you could still query the original optimizer's state_dict(). Is that the case, or is the original state stale?


Well, the higher context looks like the following:

```python
import higher

# model, opt, data and loss_function are assumed to be defined beforehand
with higher.innerloop_ctx(model, opt) as (fmodel, diffopt):
    for xs, ys in data:
        logits = fmodel(xs)  # modified `params` can also be passed as a kwarg
        loss = loss_function(logits, ys)  # no need to call loss.backward()
        diffopt.step(loss)  # note that `step` must take `loss` as an argument
```

As you can see, within this context we have a functional model (fmodel) and a differentiable optimizer (diffopt).
Inside the Higher context, only the state of fmodel and diffopt is updated.
At the end of the Higher context, I just want to copy fmodel back into my model and diffopt's state back into my optimizer.

I can do model.load_state_dict(fmodel.state_dict()) for the model, but diffopt doesn't seem to have a state_dict() method. This is a problem for me when using stateful optimizers such as Adam.