How do I backpropagate through a module's parameters?

I’m trying to implement meta-gradient learning similar to [https://arxiv.org/pdf/1805.09801.pdf].
I need to backpropagate through a parameter update.

The following is a minimal example that highlights the problem:

    import torch

    meta_param = torch.tensor(5.0, requires_grad=True)  # this is the variable I want to learn
    param = torch.nn.Parameter(torch.tensor(4.0, requires_grad=True))  # this is a module param
    param.data = meta_param * 2

    meta_objective = -(param * 2)
    meta_objective.mean().backward()

I want the gradient of meta_objective with respect to meta_param, but it has to pass through an update of the nn.Parameter param. If param were a regular tensor, like so:

    import torch

    meta_param = torch.tensor(5.0, requires_grad=True)  # this is the variable I want to learn
    param = meta_param * 2

    meta_objective = -(param * 2)
    meta_objective.mean().backward()

meta_param.grad has a sensible value. But when param is an nn.Parameter, meta_param.grad is None.
After assigning the new value to the nn.Parameter, I notice it has no grad_fn and is still a leaf tensor (is_leaf=True).
This is true for all the model weights I want to update. Is there a way to make an nn.Module’s weights differentiable? Perhaps by making a temporary model whose weights are views of non-leaf tensors?
How do I work around this?
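For anyone landing here, a quick way to see what goes wrong: assigning to `.data` is invisible to autograd, so `param` stays a leaf with no `grad_fn` and there is no graph path back to `meta_param`. The sketch below (names `meta_param2`/`param2` are just for the side-by-side comparison) contrasts the two cases from the question; in the plain-tensor case the gradient is d/dm[-(2 · 2m)] = -4.

```python
import torch

# Case 1: write into an nn.Parameter via .data (the pattern from the question)
meta_param = torch.tensor(5.0, requires_grad=True)
param = torch.nn.Parameter(torch.tensor(4.0))
param.data = meta_param * 2            # .data assignment bypasses autograd
print(param.is_leaf, param.grad_fn)    # True None: the link to meta_param is gone

(-(param * 2)).mean().backward()
print(meta_param.grad)                 # None: no graph path from the loss to meta_param

# Case 2: keep the updated value as an ordinary (non-leaf) tensor
meta_param2 = torch.tensor(5.0, requires_grad=True)
param2 = meta_param2 * 2               # recorded by autograd (has a grad_fn)
(-(param2 * 2)).mean().backward()
print(meta_param2.grad)                # tensor(-4.): d/dm of -(2 * 2m) is -4
```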


You could look at something like https://github.com/facebookresearch/higher for this purpose.
It functionalizes the model, so its parameters become plain tensors that can be swapped out and backpropagated through.
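For reference, newer PyTorch versions (2.0+, as far as I know; check yours) can do a similar functionalization without a third-party library via `torch.func.functional_call`. This is a swapped-in technique, not higher's API. A minimal sketch of one differentiable inner step, assuming plain SGD as the inner optimizer, an assumed `inner_lr`, and a hypothetical `meta_param` that scales the inner loss:

```python
import torch
from torch import nn
from torch.func import functional_call

torch.manual_seed(0)
model = nn.Linear(1, 1)
meta_param = torch.tensor(10.0, requires_grad=True)  # hypothetical hyperparameter being meta-learned
inner_lr = 3e-4                                      # assumed inner-loop step size (plain SGD)

train_x = torch.tensor([[10.0]])
xval_x = torch.tensor([[10.0]])

# Start from the module's current weights as a plain dict of name -> tensor
params = dict(model.named_parameters())

# Inner step: the loss depends on meta_param; create_graph=True keeps the
# gradients themselves differentiable so the update stays in the graph
inner_loss = -(functional_call(model, params, (train_x,)) * meta_param).sum()
grads = torch.autograd.grad(inner_loss, list(params.values()), create_graph=True)

# Updated weights are non-leaf tensors that still depend on meta_param
new_params = {name: p - inner_lr * g for (name, p), g in zip(params.items(), grads)}

# Outer (meta) objective evaluated with the updated weights
meta_objective = -functional_call(model, new_params, (xval_x,)).sum()
meta_objective.backward()
print(meta_param.grad)  # non-None: the gradient flowed through the parameter update
```

Unlike higher's differentiable Adam, this unrolls a single SGD step by hand; a longer inner loop would repeat the `grad`/update lines, each with `create_graph=True`.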


Oh my god, yes! This is exactly what I was looking for, and SO MUCH MORE.
They even have differentiable optimizers that are 1000x more elegant than what I’d hacked together.

For posterity, this is a minimal example of what I was trying to accomplish, and it works with higher:

import higher
import torch
from torch import nn

"""
Goal: use meta gradients to tune differentiable hyper params of an update step
by updating them to maximise a corss-validation step online against a second update
""" 
train_inputs = torch.tensor(10.0)
xval_inputs = torch.tensor(10.0)

meta_param = torch.tensor(10.0, requires_grad=True)
meta_optimizer = torch.optim.Adam([meta_param], lr=3e-4)

my_model = nn.Linear(1, 1)
opt_a = torch.optim.Adam(my_model.parameters(), lr=3e-4)

with higher.innerloop_ctx(my_model, opt_a) as (fmodel, diffopt):
    # training step with higher wrappers
    pred = fmodel(train_inputs.unsqueeze(0))
    pred = pred * meta_param
    loss = -pred
    diffopt.step(loss)
    meta_optimizer.zero_grad() # <- does nothing, but I had to be sure 
    
    # meta-param update step
    second_pred = fmodel(xval_inputs.unsqueeze(0))
    meta_objective = -second_pred
    meta_objective.backward() #  meta_param.grad now exists, is valid and doesn't sum grads from first loss 🎉🎉 
    meta_optimizer.step()

thanks


A solution for your minimal example that doesn’t involve an external library:

https://discuss.pytorch.org/t/non-leaf-variables-as-a-modules-parameters/65775
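One library-free approach for the minimal example (a sketch; I haven't checked that it is exactly what the linked thread proposes) is to unregister the `nn.Parameter` and assign a plain, non-leaf tensor under the same attribute name. `nn.Module.__setattr__` rejects a plain tensor while the name is still registered as a parameter, hence the `del` first. The bias-free `nn.Linear` here is a hypothetical stand-in for "a module param".

```python
import torch
from torch import nn

meta_param = torch.tensor(5.0, requires_grad=True)  # the variable to learn
model = nn.Linear(1, 1, bias=False)                 # hypothetical module standing in for "a module param"

# Unregister the Parameter, then attach a differentiable (non-leaf) tensor in its place
del model.weight
model.weight = (meta_param * 2).reshape(1, 1)  # plain tensor: forward() uses it, autograd tracks it

x = torch.tensor([[1.0]])
meta_objective = -(model(x) * 2)
meta_objective.mean().backward()
print(meta_param.grad)  # tensor(-4.): same value as the plain-tensor case
```

The trade-off: the replaced weight no longer shows up in `model.parameters()` or `state_dict()`, so an ordinary optimizer built from the module won't see it.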