I tried using higher but something isn’t working…I’ll paste my code just in case you have time to take a look:
import torch
import torch.nn as nn
from torch.optim.optimizer import Optimizer
import higher
from higher.optim import DifferentiableOptimizer
from higher.optim import DifferentiableSGD
import torchvision
import torchvision.transforms as transforms
from torchviz import make_dot
import copy
import itertools
from collections import OrderedDict
#mini class to add a flatten layer to the ordered dictionary
class Flatten(nn.Module):
def forward(self, input):
'''
Note that input.size(0) is usually the batch size.
So what it does is that given any input with input.size(0) # of batches,
will flatten to be 1 * nb_elements.
'''
batch_size = input.size(0)
out = input.view(batch_size,-1)
return out # (batch_size, *size)
def get_cifar10():
transform = transforms.Compose(
[transforms.ToTensor(),
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
trainset = torchvision.datasets.CIFAR10(root='./data', train=True,
download=True, transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=4,
shuffle=True, num_workers=2)
testset = torchvision.datasets.CIFAR10(root='./data', train=False,
download=True, transform=transform)
testloader = torch.utils.data.DataLoader(testset, batch_size=4,
shuffle=False, num_workers=2)
return trainloader, testloader
class MySGD(Optimizer):
def __init__(self, params, eta, prev_lr):
defaults = {'eta':eta, 'prev_lr':prev_lr}
super().__init__(params, defaults)
class TrainableSGD(DifferentiableOptimizer):
def _update(self, grouped_grads, **kwargs):
prev_lr = self.param_groups[0]['prev_lr']
eta = self.param_groups[0]['eta']
# start differentiable & trainable update
zipped = zip(self.param_groups, grouped_grads)
lr = 0.1*eta(prev_lr).view(1)
for group_idx, (group, grads) in enumerate(zipped):
for p_idx, (p, g) in enumerate(zip(group['params'], grads)):
if g is None:
continue
#group['params'][p_idx] = _add(p, -group['lr'], g)
p_new = p + lr*g
group['params'][p_idx] = p_new
# fake returns
self.param_groups[0]['prev_lr'] = lr
higher.register_optim(MySGD, TrainableSGD)
def main():
# get dataloaders
trainloader, testloader = get_cifar10()
criterion = nn.CrossEntropyLoss()
child_model = nn.Sequential(OrderedDict([
('conv1', nn.Conv2d(in_channels=3,out_channels=2,kernel_size=5)),
('relu1', nn.ReLU()),
('Flatten', Flatten()),
('fc', nn.Linear(in_features=28*28*2,out_features=10) )
]))
hidden = torch.randn(size=(1,1),requires_grad=True)
print(f'-> hidden = {hidden}')
eta = nn.Sequential(OrderedDict([
('fc', nn.Linear(1,1)),
('sigmoid', nn.Sigmoid())
]))
inner_opt = MySGD(child_model.parameters(), eta=eta, prev_lr=hidden)
meta_params = itertools.chain(child_model.parameters(),eta.parameters())
#meta_params = itertools.chain(eta.parameters(),[hidden])
meta_opt = torch.optim.Adam(meta_params, lr=1e-3)
# do meta-training/outer training minimize outerloop: min_{theta} sum_t L^val( theta^{T} - eta* Grad L^train(theta^{T}) )
print()
nb_outer_steps = 1 # note, in this case it's the same as number of meta-train steps (but it's could not be the same depending how you loop through the val set)
for outer_i, (outer_inputs, outer_targets) in enumerate(testloader, 0):
meta_opt.zero_grad()
if outer_i >= nb_outer_steps:
break
# do inner-training/MAML; minimize innerloop: theta^{T} - eta * Grad L^train(theta^{T}) ~ argmin L^train(theta)
nb_inner_steps = 3
#with higher.innerloop_ctx(child_model, inner_opt, copy_initial_weights=False) as (fmodel, diffopt):
with higher.innerloop_ctx(child_model, inner_opt) as (fmodel, diffopt):
for inner_i, (inner_inputs, inner_targets) in enumerate(trainloader, 0):
if inner_i >= nb_inner_steps:
break
logits = fmodel(inner_inputs)
inner_loss = criterion(logits, inner_targets)
print(f'--> inner_i = {inner_i}')
print(f'inner_loss^<{inner_i}>: {inner_loss}')
print(f'lr^<{inner_i-1}> = {diffopt.param_groups[0]["prev_lr"]}')
diffopt.step(inner_loss) # changes params P[t+1] using P[t] and loss[t] in a differentiable manner
print(f'lr^<{inner_i}> = {diffopt.param_groups[0]["prev_lr"]}')
print()
# compute the meta-loss L^val( theta^{T} - eta* Grad L^train(theta^{T}) )
outer_outputs = fmodel(outer_inputs)
meta_loss = criterion(outer_outputs, outer_targets) # L^val
make_dot(meta_loss).render('meta_loss',format='png')
meta_loss.backward()
#grad_of_grads = torch.autograd.grad(outputs=meta_loss, inputs=eta.parameters()) # dmeta_loss/dw0
print(f'----> outer_i = {outer_i}')
print(f'-> outer_loss/meta_loss^<{outer_i}>: {meta_loss}')
print(f'child_model.fc.weight.grad = {child_model.fc.weight.grad}')
print(f'hidden.grad = {hidden.grad}')
print(f'eta.fc.weight = {eta.fc.weight.grad}')
meta_opt.step() # meta-optimizer step: more or less theta^<t> := theta^<t> - meta_eta * Grad L^val( theta^{T} - eta* Grad L^train(theta^{T}) )
if __name__ == "__main__":
main()
print('---> Done\a')
notice the None’s:
Files already downloaded and verifiedFiles already downloaded and verified
-> hidden = tensor([[0.8459]], requires_grad=True)
--> inner_i = 0
inner_loss^<0>: 2.2696359157562256
lr^<-1> = tensor([[0.8459]], requires_grad=True)
lr^<0> = tensor([0.0567], grad_fn=<MulBackward0>)
--> inner_i = 1
inner_loss^<1>: 2.0114920139312744
lr^<0> = tensor([0.0567], grad_fn=<MulBackward0>)
lr^<1> = tensor([0.0720], grad_fn=<MulBackward0>)
--> inner_i = 2
inner_loss^<2>: 2.3866422176361084
lr^<1> = tensor([0.0720], grad_fn=<MulBackward0>)
lr^<2> = tensor([0.0717], grad_fn=<MulBackward0>)
----> outer_i = 0
-> outer_loss/meta_loss^<0>: 4.021303176879883
child_model.fc.weight.grad = None
hidden.grad = None
eta.fc.weight = None
---> Done