I wanted to implement the meta-lstm meta-learner in the paper OPTIMIZATION AS A MODEL FOR FEW-SHOT LEARNING using higher but I found problems. I found that I cannot make it work without removing (what seems to be this crucial line):
to:
#self.param_groups = _copy.deepcopy(other.param_groups)
self.param_groups = other.param_groups
I provide an extremely simplified self-contained implementation of something similar here:
but I will copy paste here to keep the discussion in one place:
# base on the paper "OPTIMIZATION AS A MODEL FOR FEW-SHOT LEARNING": https://openreview.net/pdf?id=rJY0-Kcll
class EmptySimpleMetaLstm(Optimizer):
def __init__(self, params, trainable_opt_model, trainable_opt_state, *args, **kwargs):
defaults = {
'trainable_opt_model':trainable_opt_model,
'trainable_opt_state':trainable_opt_state,
'args':args,
'kwargs':kwargs
}
super().__init__(params, defaults)
class SimpleMetaLstm(DifferentiableOptimizer):
def _update(self, grouped_grads, **kwargs):
prev_lr = self.param_groups[0]['trainable_opt_state']['prev_lr']
eta = self.param_groups[0]['trainable_opt_model']['eta']
# start differentiable & trainable update
zipped = zip(self.param_groups, grouped_grads)
for group_idx, (group, grads) in enumerate(zipped):
for p_idx, (p, g) in enumerate(zip(group['params'], grads)):
if g is None:
continue
# get gradient as "data"
g = g.detach() # gradients of gradients are not used (no hessians)
## very simplified version of meta-lstm meta-learner
input_metalstm = torch.stack([p, g, prev_lr.view(1,1)]).view(1,3) # [p, g, prev_lr] note it's missing loss, normalization etc. see original paper
lr = eta(input_metalstm).view(1)
fg = 1 - lr # learnable forget rate
## update suggested by meta-lstm meta-learner
p_new = fg*p - lr*g
group['params'][p_idx] = p_new
# fake returns
self.param_groups[0]['trainable_opt_state']['prev_lr'] = lr
higher.register_optim(EmptySimpleMetaLstm, SimpleMetaLstm)
def test_parametrized_inner_optimizer():
import torch
import torch.nn as nn
import torch.optim as optim
from collections import OrderedDict
## training config
#device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
track_higher_grads = True # if True, during unrolled optimization the graph be retained, and the fast weights will bear grad funcs, so as to permit backpropagation through the optimization process. False during test time for efficiency reasons
copy_initial_weights = False # if False then we train the base models initial weights (i.e. the base model's initialization)
episodes = 5
nb_inner_train_steps = 5
## get base model
base_mdl = nn.Sequential(OrderedDict([
('fc', nn.Linear(1,1, bias=False)),
('relu', nn.ReLU())
]))
## parametrization/mdl for the inner optimizer
opt_mdl = nn.Sequential(OrderedDict([
('fc', nn.Linear(3,1, bias=False)), # 3 inputs 1 for parameter, 1 for gradient, 1 for previous lr
('sigmoid', nn.Sigmoid())
]))
## get outer optimizer (not differentiable nor trainable)
outer_opt = optim.Adam([{'params': base_mdl.parameters()},{'params': opt_mdl.parameters()}], lr=0.01)
for episode in range(episodes):
## get fake support & query data (from a single task and 1 data point)
spt_x, spt_y, qry_x, qry_y = torch.randn(1), torch.randn(1), torch.randn(1), torch.randn(1)
## get differentiable & trainable (parametrized) inner optimizer
inner_opt = EmptySimpleMetaLstm(base_mdl.parameters(), trainable_opt_model={'eta': opt_mdl}, trainable_opt_state={'prev_lr': 0.9*torch.randn(1)})
with higher.innerloop_ctx(base_mdl, inner_opt, copy_initial_weights=copy_initial_weights, track_higher_grads=track_higher_grads) as (fmodel, diffopt):
for i_inner in range(nb_inner_train_steps): # this current version implements full gradient descent on k_shot examples (which is usually small 5)
fmodel.train()
# base/child model forward pass
inner_loss = 0.5*((fmodel(spt_x) - spt_y))**2
# inner-opt update
diffopt.step(inner_loss)
## Evaluate on query set for current task
qry_loss = 0.5*((fmodel(qry_x) - qry_y))**2
qry_loss.backward() # for memory efficient computation
## outer update
print(f'episode = {episode}')
print(f'base_mdl.grad = {base_mdl.fc.weight.grad}')
print(f'opt_mdl.grad = {opt_mdl.fc.weight.grad}')
outer_opt.step()
outer_opt.zero_grad()
if __name__ == '__main__':
test_parametrized_inner_optimizer()
print('Done \a')
"""
output when deep copy is uncommented (parametrized optimizer trains properly):
episode = 0
base_mdl.grad = tensor([[-0.0351]])
opt_mdl.grad = tensor([[0.0085, 0.0000, 0.0204]])
episode = 1
base_mdl.grad = tensor([[0.0311]])
opt_mdl.grad = tensor([[-0.0086, -0.0100, 0.0358]])
episode = 2
base_mdl.grad = tensor([[0.]])
opt_mdl.grad = tensor([[0., 0., 0.]])
episode = 3
base_mdl.grad = tensor([[0.0066]])
opt_mdl.grad = tensor([[-0.0016, 0.0000, -0.0032]])
episode = 4
base_mdl.grad = tensor([[-0.0311]])
opt_mdl.grad = tensor([[0.0077, 0.0000, 0.0130]])
Done
when deep copy is on (paremeters of inner optimizer are not train, sad!):
episode = 0
base_mdl.grad = tensor([[0.]])
opt_mdl.grad = None
episode = 1
base_mdl.grad = tensor([[0.]])
opt_mdl.grad = None
episode = 2
base_mdl.grad = tensor([[0.0069]])
opt_mdl.grad = None
episode = 3
base_mdl.grad = tensor([[0.]])
opt_mdl.grad = None
episode = 4
base_mdl.grad = tensor([[0.]])
opt_mdl.grad = None
Done
The deep copy line in higher I am referencing:
self.param_groups = _copy.deepcopy(other.param_groups)
#self.param_groups = other.param_groups
"""
Real Solution
The real solution is if I could pass an arbitrary dictionary to a differentiable optimizer and if I could do whatever I wanted with it.
#update:
perhaps this can be implemented with override:
override (optional) – a dictionary mapping optimizer settings (i.e. those which would be passed to the optimizer constructor or provided within parameter groups) to either singleton lists of override values, or to a list of override values of length equal to the number of parameter groups. If a single override is provided for a keyword, it is used for all parameter groups. If a list is provided, the ith element of the list overrides the corresponding setting in the ith parameter group. This permits the passing of tensors requiring gradient to differentiable optimizers for use as optimizer settings.
Didn’t work with override:
Exception has occurred: ValueError
Mismatch between the number of override tensors for optimizer parameter trainable_opt_model and the number of parameter groups.
seems like it checks that these lengths match...
def _apply_override(self, override: _OverrideType) -> None:
for k, v in override.items():
# Sanity check
if (len(v) != 1) and (len(v) != len(self.param_groups)
I think this is all I need:
inner_opt = EmptySimpleMetaLstm( base_mdl.parameters() )
with higher.innerloop_ctx(base_mdl, inner_opt, copy_initial_weights=copy_initial_weights, track_higher_grads=track_higher_grads) as (fmodel, diffopt):
diffopt.override = {'trainable_opt_model': opt_mdl, 'trainable_opt_state': {'prev_lr': 0.9*torch.randn(1)} }
cross-posted:
- git issue in higher: How does one implemented a parametrized meta-learner (like meta-lstm optimizer) in higher? · Issue #62 · facebookresearch/higher · GitHub
- Reddit - Dive into anything?
- How does one implemented a parametrized meta-learner in Pytorch's higher library?
related: