How do I backpropagate through a module's parameters?

I’m trying to implement meta-gradient learning similarly to [].
I need to backpropagate through a parameter update.

The following is a minimum example that highlights the problem:

                meta_param = torch.tensor(5.0,requires_grad=True) # this is the variable i want to learn
                param = torch.nn.Parameter(torch.tensor(4.0,requires_grad=True)) # this is a module param
                param.data = meta_param * 2

                meta_objective = - (param * 2)

I’m trying to get the gradient of meta_objective wrt meta_param, but it has to pass through an nn.Parameter update for param. If param were a regular tensor, like so:

                meta_param = torch.tensor(5.0,requires_grad=True) # this is the variable i want to learn
                param = meta_param * 2

                meta_objective = - (param * 2)

meta_param.grad has a sensible value. But when param is an nn.Parameter, meta_param.grad is None.
After assigning the new value to the nn.Parameter, I notice that its .grad_fn is None and its is_leaf is True.
This is true for all the model weights I want to update. Is there a way to make an nn.Module’s weights differentiable? Perhaps by making a temporary model whose weights are views of non-leaf tensors?
How do I work around this?
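
To make the symptom concrete, here is a small sketch of what seems to happen, assuming the update is written through param.data (the exact assignment is an assumption here). Writing to .data copies the value but not the autograd history, so the parameter stays a leaf and the graph back to meta_param is cut:

```python
import torch

meta_param = torch.tensor(5.0, requires_grad=True)  # variable to learn
param = torch.nn.Parameter(torch.tensor(4.0))       # module parameter (a leaf tensor)

# Assumed form of the update: only the value is copied, not the graph.
param.data = meta_param * 2

print(param.is_leaf, param.grad_fn)  # True None

meta_objective = -(param * 2)
meta_objective.backward()

print(param.grad)       # tensor(-2.)  the gradient stops at the leaf
print(meta_param.grad)  # None         nothing flows back to meta_param
```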


You could look at something like higher for this purpose.
It functionalizes the model, so that its parameters can be detached and backpropagated through.
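
As a rough illustration of what "functionalizing" means, the sketch below runs a module with externally supplied weights using torch.func.functional_call, which assumes PyTorch 2.0 or newer. It is not how higher is implemented; meta_lr, x_train and x_val are hypothetical names chosen for the example. The updated weights are ordinary non-leaf tensors, so the meta objective can be differentiated through the inner update:

```python
import torch
from torch import nn
from torch.func import functional_call

model = nn.Linear(1, 1)
meta_lr = torch.tensor(0.1, requires_grad=True)  # hypothetical learnable inner step size

x_train = torch.tensor([[10.0]])
x_val = torch.tensor([[10.0]])

# Inner step: run the module with an explicit dict of parameters.
params = dict(model.named_parameters())
inner_loss = -functional_call(model, params, (x_train,)).sum()
grads = torch.autograd.grad(inner_loss, list(params.values()), create_graph=True)

# Updated weights are non-leaf tensors that still depend on meta_lr.
new_params = {
    name: p - meta_lr * g
    for (name, p), g in zip(params.items(), grads)
}

# Outer step: evaluate the same module with the updated weights.
meta_objective = -functional_call(model, new_params, (x_val,)).sum()
meta_objective.backward()

print(meta_lr.grad)  # gradient of the meta objective w.r.t. the step size
```

The module's own nn.Parameter objects are never overwritten here, which is why the graph stays intact.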


oh my god yes! this is exactly what i was looking for and SO MUCH MORE
They even have differentiable optimizers that are 1000x more elegant than what I’d hacked together

For posterity, this is a minimal example of what I was trying to accomplish, and it works with higher:

import higher
import torch
from torch import nn

# Goal: use meta gradients to tune differentiable hyperparameters of an update step
# by updating them to maximise a cross-validation step online against a second update
train_inputs = torch.tensor(10.0)
xval_inputs = torch.tensor(10.0)

meta_param = torch.tensor(10.0, requires_grad=True)
meta_optimizer = torch.optim.Adam([meta_param], lr=3e-4)

my_model = nn.Linear(1, 1)
opt_a = torch.optim.Adam(my_model.parameters(), lr=3e-4)

with higher.innerloop_ctx(my_model, opt_a) as (fmodel, diffopt):
    # training step with higher wrappers
    pred = fmodel(train_inputs.unsqueeze(0))
    pred = pred * meta_param
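    loss = -pred
    diffopt.step(loss)  # differentiable inner update; without it meta_objective does not depend on meta_param
    meta_optimizer.zero_grad()  # <- does nothing, but I had to be sure
    # meta param update step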
    second_pred = fmodel(xval_inputs.unsqueeze(0))
    meta_objective = -second_pred
    meta_objective.backward() #  meta_param.grad now exists, is valid and doesn't sum grads from first loss 🎉🎉 



A solution for your minimum example without involving external lib.
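
One way such a solution could look is sketched below. It is not necessarily the code this reply had in mind, and new_param is a name introduced only for illustration. The idea is to keep the updated value as an ordinary non-leaf tensor while computing the meta objective, and to copy it into the nn.Parameter only afterwards, outside autograd:

```python
import torch

meta_param = torch.tensor(5.0, requires_grad=True)  # variable to learn
param = torch.nn.Parameter(torch.tensor(4.0))       # module parameter (a leaf tensor)

# Compute the updated value as a regular tensor instead of writing it
# into param.data, so the graph back to meta_param survives.
new_param = meta_param * 2

meta_objective = -(new_param * 2)
meta_objective.backward()
print(meta_param.grad)  # tensor(-4.)

# Only now store the value in the real parameter, without tracking it.
with torch.no_grad():
    param.copy_(new_param)
```

For a full nn.Module the same pattern needs the forward pass to accept the updated tensors explicitly, which is what the functional wrappers in higher take care of.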