In short: I want to enable `create_graph` when calling `loss.backward()` on a loss computed with torch.nn layers, so that I can later backpropagate from the final (trained) weights to get their gradient w.r.t. a hyperparameter.
In detail, I am implementing an algorithm to solve a bilevel problem (two nested optimisation problems). The parameters optimised in the inner problem can be seen as the weights of a torch.nn based model, while the parameters of the outer problem are the hyperparameters. I want to optimise both with gradient descent. To do one update of the hyperparameters, I need the gradient of the model's weights (after being trained) with respect to these hyperparameters, because the outer loss is a function of the trained model weights.
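To make the dependence concrete, here is a toy sketch without torch.nn, where the inner update is written out-of-place so the trained weight keeps its graph back to the hyperparameter (all names and the quadratic loss are just placeholders, not my real code):

```python
import torch

# placeholder data and hyperparameter
xtr = torch.randn(10, 1)
ytr = 3.0 * xtr
theta = torch.tensor(0.1, requires_grad=True)   # outer variable (hyperparameter)

w = torch.zeros(1, requires_grad=True)           # inner variable (model weight)
w_t = w
for i in range(5):
    pred = xtr * w_t
    # theta enters the inner loss, e.g. as an L2 penalty weight
    inner_loss = ((pred - ytr) ** 2).mean() + theta * (w_t ** 2).sum()
    g = torch.autograd.grad(inner_loss, w_t, create_graph=True)[0]
    w_t = w_t - 0.1 * g                           # out-of-place update, graph is kept

out_loss = (w_t ** 2).sum()                       # outer loss, a function of the trained weight
out_loss.backward()                               # populates theta.grad
print(theta.grad)
```

This is exactly the dependence I want to reproduce when the inner model is a torch.nn module trained with an optimizer.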
The problem is that even when I set `create_graph=True` when backwarding the inner loss, `optimizer.step()` performs in-place updates on the parameters, so the graph from the updated weights back to the hyperparameters is never built. The same happens when I replace `optimizer.step()` with a manual update of the model weights, since that is still an in-place update:
```python
for name, param in model.named_parameters():
    param.data = param.data - param.grad
```
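For illustration, this is the kind of out-of-place inner loop I have in mind instead: keep the updated weights in a plain dict and run the forward pass functionally, so no `nn.Parameter` is ever modified in place. A rough sketch, assuming a recent PyTorch where `torch.func.functional_call` is available (older versions have `torch.nn.utils.stateless.functional_call`); `lr` is just a placeholder learning rate:

```python
import torch
from torch.func import functional_call

lr = 0.1  # placeholder inner-loop learning rate

# snapshot of the current weights as ordinary tensors, updated out-of-place
params = {name: p.detach().clone().requires_grad_(True)
          for name, p in model.named_parameters()}

for i in range(IN_MAX_ITR):
    outputs = functional_call(model, params, (xtr,))   # forward pass with the external weights
    loss = compute_loss(outputs)                       # assumed to depend on theta somewhere
    grads = torch.autograd.grad(loss, list(params.values()), create_graph=True)
    params = {name: p - lr * g
              for (name, p), g in zip(params.items(), grads)}

# after the inner loop, any function of params differentiates back to theta
```

I am not sure this is the intended approach, which is part of my question below.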
A simplified version of the code I want to run:
```python
for t in range(OUT_MAX_ITR):
    model.train()
    for i in range(IN_MAX_ITR):
        optimizer.zero_grad()
        outputs = model(xtr)
        loss = compute_loss(outputs)
        loss.backward(create_graph=True)
        optimizer.step()
    theta.grad = None
    out_loss = function_of_model_weights()
    out_loss.backward()
    update_theta(theta, theta.grad)
```
Here `theta` is the hyperparameter to be optimised.
Is there a way, or a workaround, in PyTorch to do this kind of second-order differentiation (or bilevel optimisation) when working with torch.nn?