I’m trying to implement meta-gradient learning similar to https://arxiv.org/pdf/1805.09801.pdf.
I need to backpropagate through a parameter update.
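For reference, here is the quantity being asked for when the update is a single SGD step. The symbols are chosen for illustration only: θ are the model weights, η the meta-parameter, α an inner step size, and the inner step is assumed to be plain SGD.

```latex
\theta' = \theta - \alpha \, \nabla_\theta \mathcal{L}_{\text{train}}(\theta, \eta)

\nabla_\eta \mathcal{L}_{\text{val}}(\theta')
  = \left(\frac{\partial \theta'}{\partial \eta}\right)^{\!\top} \nabla_{\theta'} \mathcal{L}_{\text{val}}(\theta')
  = -\alpha \left(\frac{\partial^2 \mathcal{L}_{\text{train}}(\theta, \eta)}{\partial \theta \, \partial \eta}\right)^{\!\top} \nabla_{\theta'} \mathcal{L}_{\text{val}}(\theta')
```

The ∂θ'/∂η factor only exists if autograd records the computation θ → θ'; an assignment that bypasses autograd drops it, and the meta-gradient disappears with it.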
The following is a minimum example that highlights the problem:
import torch
meta_param = torch.tensor(5.0, requires_grad=True)  # this is the variable I want to learn
param = torch.nn.Parameter(torch.tensor(4.0, requires_grad=True))  # this is a module param
param.data = meta_param * 2
meta_objective = - (param * 2)
meta_objective.mean().backward()
I’m trying to get the gradient of meta_objective with respect to meta_param, but it has to pass through an nn.Parameter update of param. If param is a regular tensor, like so:
import torch
meta_param = torch.tensor(5.0, requires_grad=True)  # this is the variable I want to learn
param = meta_param * 2
meta_objective = - (param * 2)
meta_objective.mean().backward()
meta_param.grad has a sensible value. But when param is an nn.Parameter, meta_param.grad is None.
After assigning the new value to the nn.Parameter, I notice it has no .grad_fn and its is_leaf attribute is True.
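A quick way to confirm this diagnosis is to inspect is_leaf and grad_fn directly. The sketch below reuses the question's values: assigning through .data leaves param as a leaf with no recorded history, so backward() never reaches meta_param, while the plain-tensor version keeps the link.

```python
import torch

meta_param = torch.tensor(5.0, requires_grad=True)

# Case 1: overwrite an nn.Parameter through .data, as in the question.
param = torch.nn.Parameter(torch.tensor(4.0))
param.data = meta_param * 2          # .data assignment is not recorded by autograd
print(param.is_leaf, param.grad_fn)  # True None
(-(param * 2)).backward()
grad_after_data_assign = meta_param.grad
print(grad_after_data_assign)        # None: nothing links param back to meta_param

# Case 2: keep the new value as an ordinary, non-leaf tensor.
new_param = meta_param * 2
print(new_param.is_leaf, new_param.grad_fn is not None)  # False True
(-(new_param * 2)).backward()
print(meta_param.grad)               # tensor(-4.)
```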
This is true for all the model weights I want to update. Is there a way to make an nn.Module’s weights differentiable, perhaps by making a temporary model whose weights are views of non-leaf tensors?
How do I work around this?
You could look at something like https://github.com/facebookresearch/higher for this purpose.
It functionalizes the model, so its parameters are decoupled from the module and can be backpropagated through.
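To make "functionalizes the model" concrete without depending on higher, here is a sketch using torch.func.functional_call, which assumes PyTorch 2.0 or later and is not necessarily how higher works internally. The module is called with a dictionary of replacement tensors instead of its own parameters; because those replacements are non-leaf functions of meta_param (the analogue of `param = meta_param * 2` in the question), gradients flow back to it.

```python
import torch
from torch import nn
from torch.func import functional_call

torch.manual_seed(0)
model = nn.Linear(1, 1)
meta_param = torch.tensor(5.0, requires_grad=True)

# Replacement weights: each one depends on meta_param, so each has a grad_fn.
new_params = {
    name: p.detach() * meta_param
    for name, p in model.named_parameters()
}

x = torch.ones(1, 1)
# Runs model.forward with new_params substituted; model's own weights are untouched.
out = functional_call(model, new_params, (x,))
meta_objective = -out.sum()
meta_objective.backward()

print(meta_param.grad)    # a real tensor now, not None
print(model.weight.grad)  # None: the module's own leaves were bypassed
```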
Oh my god, yes! This is exactly what I was looking for, and SO MUCH MORE.
They even have differentiable optimizers that are 1000x more elegant than what I’d hacked together.
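The core idea behind a differentiable optimizer can be sketched as an SGD step built with torch.autograd.grad(..., create_graph=True) that returns new tensors instead of mutating the old ones, so the updated weights stay on the autograd tape. The helper `sgd_step` is hypothetical (it is not higher's API), and using a learnable log learning rate as the meta-parameter is an assumption for illustration.

```python
import torch

def sgd_step(params, loss, lr):
    """Hypothetical helper: one SGD step that stays on the autograd tape."""
    # create_graph=True keeps the gradients themselves differentiable.
    grads = torch.autograd.grad(loss, params, create_graph=True)
    # Out-of-place update: new non-leaf tensors, the originals are not modified.
    return [p - lr * g for p, g in zip(params, grads)]

w = torch.tensor(1.0, requires_grad=True)       # model weight
log_lr = torch.tensor(0.0, requires_grad=True)  # meta-parameter: log learning rate

train_loss = (w - 3.0) ** 2
(w_new,) = sgd_step([w], train_loss, log_lr.exp())

val_loss = (w_new - 2.0) ** 2
val_loss.backward()
print(log_lr.grad)  # tensor(24.)
```

Writing the step in place (or through .data) would cut the chain from val_loss back to log_lr; returning fresh tensors is what keeps the update differentiable.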
For posterity, this is a minimal example of what I was trying to accomplish, and it works with higher:
import higher
import torch
from torch import nn
"""
Goal: use meta-gradients to tune differentiable hyperparameters of an update step,
by updating them online to maximise a cross-validation objective computed on the updated model
"""
train_inputs = torch.tensor(10.0)
xval_inputs = torch.tensor(10.0)
meta_param = torch.tensor(10.0, requires_grad=True)
meta_optimizer = torch.optim.Adam([meta_param], lr=3e-4)
my_model = nn.Linear(1, 1)
opt_a = torch.optim.Adam(my_model.parameters(), lr=3e-4)
with higher.innerloop_ctx(my_model, opt_a) as (fmodel, diffopt):
    # training step with higher wrappers
    pred = fmodel(train_inputs.unsqueeze(0))
    pred = pred * meta_param
    loss = -pred
    diffopt.step(loss)
    meta_optimizer.zero_grad()  # <- does nothing, but I had to be sure
    # meta param update step
    second_pred = fmodel(xval_inputs.unsqueeze(0))
    meta_objective = -second_pred
    meta_objective.backward()  # meta_param.grad now exists, is valid and doesn't sum grads from first loss 🎉🎉
    meta_optimizer.step()
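If adding higher as a dependency isn't an option, the same single meta-update can be sketched with torch.func.functional_call (PyTorch 2.0+ assumed) and a hand-written differentiable step. Two substitutions to note: the inner optimizer is plain SGD with an assumed step size `inner_lr` instead of Adam, and the inputs are shaped (1, 1) up front. Starting from detached copies of the module's weights roughly mirrors higher's default `copy_initial_weights=True`.

```python
import torch
from torch import nn
from torch.func import functional_call

torch.manual_seed(0)
train_inputs = torch.tensor([[10.0]])
xval_inputs = torch.tensor([[10.0]])
meta_param = torch.tensor(10.0, requires_grad=True)
meta_optimizer = torch.optim.Adam([meta_param], lr=3e-4)
my_model = nn.Linear(1, 1)
inner_lr = 3e-4  # assumed SGD step size for the inner update

# Inner loop starts from detached copies, so my_model's own weights get no grads.
params = {n: p.detach().clone().requires_grad_() for n, p in my_model.named_parameters()}

# Training step: differentiable SGD instead of Adam.
pred = functional_call(my_model, params, (train_inputs,)) * meta_param
loss = -pred.sum()
grads = torch.autograd.grad(loss, list(params.values()), create_graph=True)
fast_params = {n: p - inner_lr * g for (n, p), g in zip(params.items(), grads)}

# Meta step: evaluate the updated weights and backprop into meta_param.
meta_optimizer.zero_grad()
second_pred = functional_call(my_model, fast_params, (xval_inputs,))
meta_objective = -second_pred.sum()
meta_objective.backward()
meta_optimizer.step()
```

With SGD and these inputs the meta-gradient works out to -(10·10 + 1)·inner_lr, which is easy to check by hand. Swapping in a differentiable Adam step would mean reimplementing its moment updates out of place, which is the bookkeeping higher's differentiable optimizers handle for you.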