Disable "in-place" updates in torch.nn

In short, I want to enable “create_graph” when calling loss.backward() while using torch.nn layers, so that I can do param.backward() to get the gradient of the final weights w.r.t. a hyperparameter.
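For context, create_graph=True makes autograd record the gradient computation itself, so the resulting gradient can be differentiated again. A minimal sketch on a scalar (torch.autograd.grad is used here instead of .backward() so the gradients are returned rather than accumulated into .grad):

```python
import torch

x = torch.tensor(3.0, requires_grad=True)
y = x ** 3

# First derivative, keeping its graph so it can itself be differentiated.
(dy_dx,) = torch.autograd.grad(y, x, create_graph=True)  # 3 * x**2 = 27
# Second derivative: differentiate the first derivative again.
(d2y_dx2,) = torch.autograd.grad(dy_dx, x)  # 6 * x = 18
print(dy_dx.item(), d2y_dx2.item())  # 27.0 18.0
```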

In detail, I am implementing an algorithm to solve a bilevel problem (two nested problems). The parameters optimised in the inner problem can be seen as the weights of a torch.nn-based model, and the parameters of the outer problem as the hyperparameters. I want to optimise both with gradient descent. To do one update of the hyperparameters, I therefore need the gradient of the model's weights (after training) with respect to these hyperparameters, because the loss of the hyperparameter optimisation is a function of the trained model weights.
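One way to write this down, assuming the inner problem is approximated by K steps of plain gradient descent with step size α (the symbols F for the outer loss, L for the inner loss, w for the weights and θ for the hyperparameters are illustrative, not from the original post):

$$
w_{k+1}(\theta) = w_k(\theta) - \alpha \, \nabla_w L\big(w_k(\theta), \theta\big), \qquad k = 0, \dots, K-1
$$

$$
\min_\theta \; F\big(w_K(\theta)\big), \qquad
\frac{dF}{d\theta} = \frac{\partial F}{\partial w_K} \, \frac{d w_K}{d\theta}
$$

$$
\frac{d w_{k+1}}{d\theta} = \big(I - \alpha \, \nabla^2_{ww} L\big) \frac{d w_k}{d\theta} - \alpha \, \nabla^2_{w\theta} L, \qquad \frac{d w_0}{d\theta} = 0
$$

with the second derivatives evaluated at (w_k, θ). These second-order terms are why the inner gradient has to be computed with create_graph=True and why every update has to be recorded by autograd.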

The problem is that even when I set create_graph=True when calling backward on the inner loss, optimizer.step() performs in-place updates, so the graph cannot be built. The same happens when I replace optimizer.step() with manual updates of the model weights, since these are still in-place:

        for name, param in model.named_parameters():
            param.data = param.data - param.grad
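A quick way to check that the optimizer update is not recorded is to inspect a parameter after optimizer.step(): it is still a leaf with no grad_fn, so nothing links its new value to anything upstream. A minimal sketch, assuming a toy nn.Linear model and random data (both placeholders):

```python
import torch

model = torch.nn.Linear(2, 1)
opt = torch.optim.SGD(model.parameters(), lr=0.1)

x = torch.randn(4, 2)
loss = model(x).pow(2).mean()
loss.backward(create_graph=True)  # may warn about a reference cycle; harmless here
opt.step()

w = model.weight
# By default step() runs under torch.no_grad() and updates the parameter
# in place, so the updated weight is still a leaf with no grad_fn.
print(w.is_leaf, w.grad_fn)  # True None
```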

A simplified version of what I want to do:

for t in range(OUT_MAX_ITR):
    model.train()
    for i in range(IN_MAX_ITR):
        optimizer.zero_grad()
        outputs = model(xtr)
        loss = compute_loss(outputs)
        loss.backward(create_graph=True)
        optimizer.step()
 
    theta.grad = None
    out_loss = function_of_model_weights()
    out_loss.backward()
    update_theta(theta, theta.grad)

Here theta is the hyperparameter to be optimised.
Is there a way or a workaround in torch to do this second-order differentiation (or bilevel optimisation) when working with torch.nn?

Hi,

I don’t think the in-place update is the main problem here (even though it might become one later on).

The main issue is that the optimizers are not differentiable, so it is expected that you cannot differentiate through them.
Writing the update manually will work as long as you do not use .data (which you should never use): it behaves similarly to .detach(), meaning that you explicitly ask autograd not to track this op.
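To illustrate the difference, here is a minimal sketch on plain scalar tensors: one out-of-place gradient step keeps the step size in the graph, while the same step written through .data does not. The toy inner loss and the use of the step size lr as a stand-in for theta are assumptions for illustration only:

```python
import torch

def inner_loss(w):
    # Hypothetical toy inner objective: (w - 3)^2
    return (w - 3.0) ** 2

# Out-of-place update: lr (standing in for theta) stays in the graph.
lr = torch.tensor(0.1, requires_grad=True)
w0 = torch.tensor(1.0, requires_grad=True)
(g,) = torch.autograd.grad(inner_loss(w0), w0, create_graph=True)  # g = -4
w1 = w0 - lr * g  # w1 = 1.4
(w1 ** 2).backward()
print(lr.grad)  # d(w1^2)/d(lr) = 2 * w1 * (-g), roughly 11.2

# Same update through .data: the new value is written without being tracked.
lr_bad = torch.tensor(0.1, requires_grad=True)
w_bad = torch.tensor(1.0, requires_grad=True)
(g_bad,) = torch.autograd.grad(inner_loss(w_bad), w_bad, create_graph=True)
w_bad.data = w_bad.data - lr_bad * g_bad
(w_bad ** 2).backward()
print(lr_bad.grad)  # None: lr_bad is not connected to the outer loss
```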

Thank you for the reply.
So is there a way to manually iterate over the parameters of a model and update them without the param.data trick, while ensuring that the updates are tracked? I.e. let torch build the computational graph that links the successive updates, so that I can call backward() later on.

I have no problem giving up the built-in optimizers; I just couldn’t find a way to, for example, iterate over the model as a dictionary (e.g. keys -> layer_name, values -> layer_weights) and make the keys refer to a different place in memory while keeping track of the model’s updates.

If that is doable in torch, I would be grateful for a hint. If not, I would appreciate any suggestions or alternative torch-based packages that can do it (like JAX with TF).

The main limitation here is that the parameters are “supposed” to remain parameters, so they cannot have history.
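Concretely, nn.Module refuses to register a plain tensor with history under a parameter name, and wrapping that tensor in nn.Parameter creates a fresh leaf, dropping the history. A minimal sketch, assuming a toy nn.Linear model:

```python
import torch

model = torch.nn.Linear(2, 1)
new_weight = model.weight - 0.1  # non-leaf: it has history (a grad_fn)

try:
    model.weight = new_weight  # nn.Module only accepts nn.Parameter (or None) here
    assigned = True
except TypeError as err:
    assigned = False
    print("TypeError:", err)

# Wrapping it in nn.Parameter is accepted, but yields a new leaf without history.
wrapped = torch.nn.Parameter(new_weight)
print(wrapped.is_leaf, wrapped.grad_fn)  # True None
```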

I think the simplest way is to use the experimental API that you can find here: pytorch/_stateless.py at e4a9ee8d42449aced60660e9afdd5b8b1d0d29c5 · pytorch/pytorch · GitHub (you can install the nightly build to be able to use it).
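The key idea is that functional_call runs the module's forward with the tensors you pass instead of its registered parameters, so those tensors are free to carry history. A minimal sketch, assuming PyTorch 2.0 or later, where this API is available as torch.func.functional_call; the scaling by a hypothetical hyperparameter theta is only there to show that gradients flow back to it:

```python
import torch
from torch.func import functional_call

model = torch.nn.Linear(3, 1)
x = torch.randn(5, 3)
theta = torch.tensor(0.5, requires_grad=True)  # hypothetical hyperparameter

# Replacement tensors: ordinary non-leaf tensors that depend on theta.
params = {name: p * theta for name, p in model.named_parameters()}

out = functional_call(model, params, (x,))  # model.forward using `params`
expected = x @ params["weight"].T + params["bias"]
print(torch.allclose(out, expected, atol=1e-6))  # True

out.sum().backward()
print(theta.grad)  # defined: the output depends on theta through params
```

Since the output is linear in theta here, theta.grad equals the sum of the original model's outputs, which makes the sketch easy to sanity-check.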

This would look like:

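import torch
# Experimental at the time; public as torch.func.functional_call in PyTorch >= 2.0
from torch.nn.utils import _stateless

for t in range(OUT_MAX_ITR):
    model.train()
    # initially params are the actual parameters (named_parameters() is a generator)
    params = dict(model.named_parameters())

    for i in range(IN_MAX_ITR):
        outputs = _stateless.functional_call(model, params, (xtr,))
        loss = compute_loss(outputs)
        # Cannot use the plain optimizers here as they are not differentiable.
        # autograd.grad returns the gradients (with their own graph) instead of
        # accumulating them into the .grad of the original parameters.
        grads = torch.autograd.grad(loss, list(params.values()), create_graph=True)
        # Out-of-place update: each new tensor keeps track of how it was computed.
        params = {k: p - g for (k, p), g in zip(params.items(), grads)}

    out_loss = function_of_params(params)
    out_loss.backward()
    # Here model.parameters() will have their .grad field populated.
    # Any other leaf Tensor that out_loss depends on (not sure what theta is
    # in your example) will also get its .grad field populated.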
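Putting the pieces together, here is a self-contained toy version of the loop above. It is a sketch under several assumptions: PyTorch 2.0+ (torch.func.functional_call), random regression data, a hypothetical hyperparameter theta taken to be the log of an L2 penalty weight in the inner loss, and arbitrary step sizes and iteration counts. Each outer iteration restarts the inner training from the model's initial weights, which is one possible design choice:

```python
import torch
from torch.func import functional_call

torch.manual_seed(0)
IN_MAX_ITR, OUT_MAX_ITR, INNER_LR, OUTER_LR = 20, 5, 0.1, 0.05

# Hypothetical toy data: train split (inner problem), validation split (outer).
true_w = torch.randn(4, 1)
x_tr, x_val = torch.randn(32, 4), torch.randn(32, 4)
y_tr = x_tr @ true_w + 0.1 * torch.randn(32, 1)
y_val = x_val @ true_w + 0.1 * torch.randn(32, 1)

model = torch.nn.Linear(4, 1)
theta = torch.tensor(-2.0, requires_grad=True)  # log of the L2 penalty weight

for t in range(OUT_MAX_ITR):
    params = dict(model.named_parameters())
    for i in range(IN_MAX_ITR):
        pred = functional_call(model, params, (x_tr,))
        penalty = sum((p ** 2).sum() for p in params.values())
        loss = torch.nn.functional.mse_loss(pred, y_tr) + theta.exp() * penalty
        grads = torch.autograd.grad(loss, list(params.values()), create_graph=True)
        # Out-of-place step: the new tensors stay connected to theta.
        params = {k: p - INNER_LR * g for (k, p), g in zip(params.items(), grads)}

    val_pred = functional_call(model, params, (x_val,))
    out_loss = torch.nn.functional.mse_loss(val_pred, y_val)
    theta.grad = None
    model.zero_grad()
    out_loss.backward()  # populates theta.grad (and the model parameters' .grad)
    with torch.no_grad():
        theta -= OUTER_LR * theta.grad  # plain gradient step on the hyperparameter
    print(t, out_loss.item(), theta.item())
```

Because the whole unrolled inner loop is kept in memory for the outer backward pass, memory use grows roughly linearly with IN_MAX_ITR.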

Take a look at this implementation, which does inner-loop optimization in a meta-learning package.