How does one make the parameters of a model NOT be leaves?

Self-contained script that seems to work:

import torch
import torch.nn as nn

from torchviz import make_dot

import copy

from collections import OrderedDict

# img = torch.randn([8,3,32,32])
# targets = torch.LongTensor([1, 2, 0, 6, 2, 9, 4, 9])
# img = torch.randn([1,3,32,32])
# targets = torch.LongTensor([1])
x = torch.randn(1)
target = 12.0*x**2

criterion = nn.CrossEntropyLoss()  # unused in this toy example

#loss_net = nn.Sequential(OrderedDict([('conv0',nn.Conv2d(in_channels=3,out_channels=10,kernel_size=32))]))
loss_net = nn.Sequential(OrderedDict([('fc0', nn.Linear(in_features=1,out_features=1))]))

hidden = torch.randn(size=(1,1),requires_grad=True)
updater_net = nn.Sequential(OrderedDict([('fc0',nn.Linear(in_features=1,out_features=1))]))
print(f'updater_net.fc0.weight.is_leaf = {updater_net.fc0.weight.is_leaf}')
#
def del_attr(obj, names):
    if len(names) == 1:
        delattr(obj, names[0])
    else:
        del_attr(getattr(obj, names[0]), names[1:])
def set_attr(obj, names, val):
    if len(names) == 1:
        setattr(obj, names[0], val)
    else:
        set_attr(getattr(obj, names[0]), names[1:], val)

nb_updates = 2
for i in range(nb_updates):
    print(f'i = {i}')
    new_params = copy.deepcopy(loss_net.state_dict())  # snapshot of the original params (unused below)
    ## w^<t> := f(w^<t-1>,delta^<t-1>)
    for (name, w) in list(loss_net.named_parameters()):
        hidden = updater_net(hidden).view(1)
        #delta = ((hidden**2)*w/2)
        delta = w + hidden
        wt = w + delta
        del_attr(loss_net, name.split("."))
        set_attr(loss_net, name.split("."), wt)
    ##
#
print()
print(f'updater_net.fc0.weight.is_leaf = {updater_net.fc0.weight.is_leaf}')
print(f'loss_net.fc0.weight.is_leaf = {loss_net.fc0.weight.is_leaf}')
outputs = loss_net(x)
loss_val = 0.5*(target - outputs)**2
loss_val.backward()
print()
print(f"-- params that don't matter if they have gradients --")
print(f'loss_net.grad = {loss_net.fc0.weight.grad}')
print('-- params we want to have gradients --')
print(f'hidden.grad = {hidden.grad}') # None: the name hidden was rebound to a non-leaf output in the loop above.
print(f'updater_net.fc0.weight.grad = {updater_net.fc0.weight.grad}')
print(f'updater_net.fc0.bias.grad = {updater_net.fc0.bias.grad}')
make_dot(loss_val)

output:

updater_net.fc0.weight.is_leaf = True
i = 0
i = 1

updater_net.fc0.weight.is_leaf = True
loss_net.fc0.weight.is_leaf = False

-- params that don't matter if they have gradients --
loss_net.grad = None
-- params we want to have gradients --
hidden.grad = None
updater_net.fc0.weight.grad = tensor([[0.7152]])
updater_net.fc0.bias.grad = tensor([-7.4249])

Since the parameter names are arbitrary dotted paths, you need helpers that can delete and set attributes recursively (to delete the nn.Parameter and set a plain Tensor in its place). So these are recursive functions that, given the list of name components like [“foo”, “bar”, “weight”], will either set or delete obj.foo.bar.weight.
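For example, a minimal sketch using the two helpers above (the layer name here is just illustrative):

# sketch: swap each leaf nn.Parameter for a non-leaf Tensor
net = nn.Sequential(OrderedDict([('fc0', nn.Linear(1, 1))]))
scale = torch.randn(1, requires_grad=True)
for name, w in list(net.named_parameters()):
    del_attr(net, name.split("."))             # removes the nn.Parameter
    set_attr(net, name.split("."), w * scale)  # plain Tensor with history
print(net.fc0.weight.is_leaf)  # False: it is now part of a graph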

AlbanD, this doesn’t quite work. I wanted to make sure it worked by printing the computation graph: increasing nb_updates should grow the computation graph, but it doesn’t. I believe it has something to do with the fact that we are changing/mutating the OrderedDict of params while looping through it at the same time. I tried to fix it by collecting the parameters before and after running the loop, but somehow it thinks the list is empty on the second pass (i.e. t=1 doesn’t print anything when it should):

import torch
import torch.nn as nn

from torchviz import make_dot

import copy

from collections import OrderedDict

def del_attr(obj, names):
    if len(names) == 1:
        delattr(obj, names[0])
    else:
        del_attr(getattr(obj, names[0]), names[1:])
        
def set_attr(obj, names, val):
    if len(names) == 1:
        setattr(obj, names[0], val)
    else:
        set_attr(getattr(obj, names[0]), names[1:], val)

        
x = torch.randn(1, requires_grad=True)
y = torch.randn(1, requires_grad=True)

criterion = nn.CrossEntropyLoss()

loss_net = nn.Sequential(OrderedDict([('l_fc0', nn.Linear(in_features=1,out_features=1, bias=True))]))
loss_net.l_fc0.weight.requires_grad=False
loss_net.l_fc0.bias.requires_grad=False

hidden = torch.randn(size=(1,1),requires_grad=True)
updater_net = nn.Sequential(OrderedDict([('u_fc0',nn.Linear(in_features=1,out_features=1))]))
updater_net.u_fc0.bias.requires_grad = False
print(f'updater_net.u_fc0.weight.is_leaf = {updater_net.u_fc0.weight.is_leaf}')
#
outputs_virgin = loss_net(x)
params = dict(dict(loss_net.named_parameters()),**{'x':x})
make_dot(outputs_virgin, params=params).render('loss_net_x', format='png')
#
nb_updates = 2
params = list(loss_net.named_parameters())
for t in range(nb_updates):
    print(f't = {t}')
    ## w^<t> := f(w^<t-1>,delta^<t-1>)
    for (name, w) in params:
        delta = updater_net(hidden).view(1)
        wt = w + delta
        print(f'w^<{t}> = {wt}')
        del_attr(loss_net, name.split("."))
        set_attr(loss_net, name.split("."), wt)
    params = list(loss_net.named_parameters())  # empty after the first pass: the Parameters were replaced by plain Tensors

print()
print(f'updater_net.u_fc0.weight.is_leaf = {updater_net.u_fc0.weight.is_leaf}')
print(f'loss_net.l_fc0.weight.is_leaf = {loss_net.l_fc0.weight.is_leaf}')

outputs = loss_net(x)
loss_val = (outputs - y)**2
loss_val.backward()
print()
print(f"-- params that don't matter if they have gradients --")
print(f'loss_net.grad = {loss_net.l_fc0.weight.grad}')
print('-- params we want to have gradients --')
print(f'hidden.grad = {hidden.grad}') # None because this is not a leaf, it is overridden in the for loop above.
print(f'updater_net.u_fc0.weight.grad = {updater_net.u_fc0.weight.grad}')
#print(f'updater_net.u_fc0.bias.grad = {updater_net.u_fc0.bias.grad}')
params = dict(dict(updater_net.named_parameters()),**{'x':x,'y':y,'hidden':hidden},**dict(loss_net.named_parameters()))
make_dot(loss_val, params=params).render('loss_val', format='png')

Well, you should not use loss_net.named_parameters() anymore once we override them. So you should “save” them before doing the first override, like params = list(loss_net.named_parameters()) before the for-loop.
Then you can set them back (as leaves) when you want to run a new iteration. So maybe extract them as a state_dict and restore them with load_state_dict.
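Something along these lines (a rough sketch, not tested; it reuses the del_attr/set_attr helpers and re-registers fresh leaf Parameters before each unroll):

# sketch: snapshot once, then restore leaf Parameters per unroll
initial_sd = copy.deepcopy(loss_net.state_dict())  # taken while still Parameters
def restore_as_leaves(net, sd):
    for name, w in sd.items():
        del_attr(net, name.split("."))
        set_attr(net, name.split("."), nn.Parameter(w.clone()))
for t in range(nb_updates):
    restore_as_leaves(loss_net, initial_sd)  # named_parameters() is non-empty again
    for name, w in list(loss_net.named_parameters()):
        ...  # override with the non-leaf wt as before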

I don’t think we should do that: if we did, they would be used as non-leaves in the next iteration and wouldn’t appear in the computation graph properly. It should be a chain of graphs created by the iterative use of wt.
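For intuition, here is a module-free sketch of the chained graph I’m after (toy numbers):

# toy sketch of the unrolled chain w^<t> = f(w^<t-1>, ...)
w = torch.randn(1)                      # initial weight (a constant)
u = torch.randn(1, requires_grad=True)  # stands in for the updater's params
for t in range(3):
    w = w + u * w    # each w^<t> keeps w^<t-1> in its history
loss = (w - 1.0) ** 2
loss.backward()
print(u.grad)        # gradient flows through all three unrolled steps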

Let me think about it…

I also thought we could put some special marker string, like the word param, in the field name when we set it, and then loop through all the fields of the object that contain that marker.

Note that if the goal is to do learning through your optimizer step, the higher library already implements this, and has a nice API for all that :slight_smile:
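The usual pattern looks roughly like this (a sketch, assuming torch, torch.nn and higher are imported; copy_initial_weights=False so gradients reach the original parameters):

# sketch of the basic higher pattern for an unrolled inner loop
model = nn.Linear(1, 1)
opt = torch.optim.SGD(model.parameters(), lr=0.1)
x, y = torch.randn(4, 1), torch.randn(4, 1)
with higher.innerloop_ctx(model, opt, copy_initial_weights=False) as (fmodel, diffopt):
    for _ in range(3):
        inner_loss = ((fmodel(x) - y) ** 2).mean()
        diffopt.step(inner_loss)  # differentiable parameter update
    meta_loss = ((fmodel(x) - y) ** 2).mean()
    meta_loss.backward()          # populates .grad on model.parameters()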


thanks for that!

I’ve been playing around with the library, but I was wondering: is it possible to have a trainable step-size with it?

I tried:

#
child_model = nn.Sequential(OrderedDict([
        ('conv1', nn.Conv2d(in_channels=3,out_channels=2,kernel_size=5)),
        ('relu1', nn.ReLU()),
        ('Flatten', Flatten()),
        ('fc', nn.Linear(in_features=28*28*2,out_features=10) )
    ]))
eta = nn.Sequential(OrderedDict([
    ('fc', nn.Linear(1,1)),
    ('sigmoid', nn.Sigmoid())
]))
inner_opt = torch.optim.Adam(child_model.parameters(), lr=eta)  # bug: lr expects a float, not a module
meta_params = itertools.chain(child_model.parameters(),eta.parameters())
meta_opt = torch.optim.Adam(meta_params, lr=1e-3)

but it failed with error:

Exception has occurred: TypeError
'<=' not supported between instances of 'float' and 'Sequential'

Or, even better, could the update rule be some sort of NN…

Starting to think that going back to your setattr approach might be better X’D…

You can simply replace the optimizer by an nn?
Or fill in the .grad fields with your nn’s result and then use a simple SGD to do the step.
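Roughly like this (a sketch; update_net is a hypothetical module mapping one gradient entry to one update entry, and the no_grad means gradients will not flow through this step):

# sketch of the ".grad fill-in" idea; update_net is hypothetical
model = nn.Linear(1, 1)
update_net = nn.Linear(1, 1)
sgd = torch.optim.SGD(model.parameters(), lr=1.0)
loss = (model(torch.randn(4, 1)) ** 2).mean()
loss.backward()
with torch.no_grad():
    for p in model.parameters():
        p.grad = update_net(p.grad.view(-1, 1)).view_as(p)
sgd.step()  # applies -lr * (filled-in grad) in place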

That’s what we were doing (pretty much) with our first example (the one where you suggested the modified set_attr & del_attr).

Or perhaps you mean to implement my own custom optimizer, let it have/be a nn, and then use the higher library so that gradients flow all the way (as in the example that you made work, but for more than 1 step)?

I will go try that right now…

Oh, that’s an interesting idea. I wonder if the gradients will flow all the way to the beginning as I require, as in the example you made work. The fill-in for .grad must not be in-place, because I want gradients to flow through this step…

I tried using higher but something isn’t working… I’ll paste my code just in case you have time to take a look:


import torch
import torch.nn as nn
from torch.optim.optimizer import Optimizer

import higher
from higher.optim import DifferentiableOptimizer
from higher.optim import DifferentiableSGD

import torchvision
import torchvision.transforms as transforms

from torchviz import make_dot

import copy

import itertools

from collections import OrderedDict

#mini class to add a flatten layer to the ordered dictionary
class Flatten(nn.Module):
    def forward(self, input):
        '''
        Note that input.size(0) is usually the batch size.
        So what it does is that given any input with input.size(0) # of batches,
        will flatten to be 1 * nb_elements.
        '''
        batch_size = input.size(0)
        out = input.view(batch_size,-1)
        return out # (batch_size, *size)

def get_cifar10():
    transform = transforms.Compose(
        [transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])

    trainset = torchvision.datasets.CIFAR10(root='./data', train=True,
                                            download=True, transform=transform)
    trainloader = torch.utils.data.DataLoader(trainset, batch_size=4,
                                            shuffle=True, num_workers=2)

    testset = torchvision.datasets.CIFAR10(root='./data', train=False,
                                        download=True, transform=transform)
    testloader = torch.utils.data.DataLoader(testset, batch_size=4,
                                            shuffle=False, num_workers=2)
    return trainloader, testloader

class MySGD(Optimizer):

    def __init__(self, params, eta, prev_lr):
        defaults = {'eta':eta, 'prev_lr':prev_lr}
        super().__init__(params, defaults)

class TrainableSGD(DifferentiableOptimizer):

    def _update(self, grouped_grads, **kwargs):
        prev_lr = self.param_groups[0]['prev_lr']
        eta = self.param_groups[0]['eta']
        # start differentiable & trainable update
        zipped = zip(self.param_groups, grouped_grads)
        lr = 0.1*eta(prev_lr).view(1)
        for group_idx, (group, grads) in enumerate(zipped):
            for p_idx, (p, g) in enumerate(zip(group['params'], grads)):
                if g is None:
                    continue
                #group['params'][p_idx] = _add(p, -group['lr'], g)
                p_new = p + lr*g
                group['params'][p_idx] = p_new
        # fake returns
        self.param_groups[0]['prev_lr'] = lr

higher.register_optim(MySGD, TrainableSGD)

def main():
    # get dataloaders
    trainloader, testloader = get_cifar10()
    criterion = nn.CrossEntropyLoss()

    child_model = nn.Sequential(OrderedDict([
            ('conv1', nn.Conv2d(in_channels=3,out_channels=2,kernel_size=5)),
            ('relu1', nn.ReLU()),
            ('Flatten', Flatten()),
            ('fc', nn.Linear(in_features=28*28*2,out_features=10) )
        ]))

    hidden = torch.randn(size=(1,1),requires_grad=True)
    print(f'-> hidden = {hidden}')
    eta = nn.Sequential(OrderedDict([
        ('fc', nn.Linear(1,1)),
        ('sigmoid', nn.Sigmoid())
    ]))
    inner_opt = MySGD(child_model.parameters(), eta=eta, prev_lr=hidden)
    meta_params = itertools.chain(child_model.parameters(),eta.parameters())
    #meta_params = itertools.chain(eta.parameters(),[hidden])
    meta_opt = torch.optim.Adam(meta_params, lr=1e-3)
    # do meta-training/outer training minimize outerloop: min_{theta} sum_t L^val( theta^{T} - eta* Grad L^train(theta^{T}) ) 
    print()
    nb_outer_steps = 1 # note: here this equals the number of meta-train steps (but it need not be, depending on how you loop through the val set)
    for outer_i, (outer_inputs, outer_targets) in enumerate(testloader, 0):
        meta_opt.zero_grad()
        if outer_i >= nb_outer_steps:
            break
        # do inner-training/MAML; minimize innerloop: theta^{T} - eta * Grad L^train(theta^{T}) ~ argmin L^train(theta)
        nb_inner_steps = 3
        #with higher.innerloop_ctx(child_model, inner_opt, copy_initial_weights=False) as (fmodel, diffopt):
        with higher.innerloop_ctx(child_model, inner_opt) as (fmodel, diffopt):
            for inner_i, (inner_inputs, inner_targets) in enumerate(trainloader, 0):
                if inner_i >= nb_inner_steps:
                    break
                logits = fmodel(inner_inputs)
                inner_loss = criterion(logits, inner_targets)
                print(f'--> inner_i = {inner_i}')
                print(f'inner_loss^<{inner_i}>: {inner_loss}')
                print(f'lr^<{inner_i-1}> = {diffopt.param_groups[0]["prev_lr"]}') 
                diffopt.step(inner_loss) # changes params P[t+1] using P[t] and loss[t] in a differentiable manner
                print(f'lr^<{inner_i}> = {diffopt.param_groups[0]["prev_lr"]}')
                print()
            # compute the meta-loss L^val( theta^{T} - eta* Grad L^train(theta^{T}) ) 
            outer_outputs = fmodel(outer_inputs)
            meta_loss = criterion(outer_outputs, outer_targets) # L^val
            make_dot(meta_loss).render('meta_loss',format='png')
            meta_loss.backward()
            #grad_of_grads = torch.autograd.grad(outputs=meta_loss, inputs=eta.parameters()) # dmeta_loss/dw0
            print(f'----> outer_i = {outer_i}')
            print(f'-> outer_loss/meta_loss^<{outer_i}>: {meta_loss}')
            print(f'child_model.fc.weight.grad = {child_model.fc.weight.grad}')
            print(f'hidden.grad = {hidden.grad}')
            print(f'eta.fc.weight = {eta.fc.weight.grad}')
            meta_opt.step() # meta-optimizer step: more or less theta^<t> := theta^<t> - meta_eta * Grad L^val( theta^{T} - eta* Grad L^train(theta^{T}) )

if __name__ == "__main__":
    main()
    print('---> Done\a')

notice the Nones:

Files already downloaded and verified
Files already downloaded and verified
-> hidden = tensor([[0.8459]], requires_grad=True)

--> inner_i = 0
inner_loss^<0>: 2.2696359157562256
lr^<-1> = tensor([[0.8459]], requires_grad=True)
lr^<0> = tensor([0.0567], grad_fn=<MulBackward0>)

--> inner_i = 1
inner_loss^<1>: 2.0114920139312744
lr^<0> = tensor([0.0567], grad_fn=<MulBackward0>)
lr^<1> = tensor([0.0720], grad_fn=<MulBackward0>)

--> inner_i = 2
inner_loss^<2>: 2.3866422176361084
lr^<1> = tensor([0.0720], grad_fn=<MulBackward0>)
lr^<2> = tensor([0.0717], grad_fn=<MulBackward0>)

----> outer_i = 0
-> outer_loss/meta_loss^<0>: 4.021303176879883
child_model.fc.weight.grad = None
hidden.grad = None
eta.fc.weight = None
---> Done

related:

The git issue is for the same thing? Does it answer your problem?

yes it’s the same thing as I posted here.

My problem is NOT solved. I’m working on it.

@albanD my thing is nearly working, but the only issue is that PyTorch thinks I’m doing a backward pass twice on the same graph. I am not sure why it thinks that, because I delete the output node as you said (here: How to free graph manually?).

Do you know why it might not be working?

Code:

import torch
import torch.nn as nn
from torch.optim.optimizer import Optimizer, required

import higher
from higher.optim import DifferentiableOptimizer
from higher.optim import DifferentiableSGD

import torchvision
import torchvision.transforms as transforms

from torchviz import make_dot

import copy

import itertools

import sys

from collections import OrderedDict

from pdb import set_trace as st

#mini class to add a flatten layer to the ordered dictionary
class Flatten(nn.Module):
    def forward(self, input):
        '''
        Note that input.size(0) is usually the batch size.
        So what it does is that given any input with input.size(0) # of batches,
        will flatten to be 1 * nb_elements.
        '''
        batch_size = input.size(0)
        out = input.view(batch_size,-1)
        return out # (batch_size, *size)

def get_cifar10():
    transform = transforms.Compose(
        [transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])

    trainset = torchvision.datasets.CIFAR10(root='./data', train=True,
                                            download=True, transform=transform)
    trainloader = torch.utils.data.DataLoader(trainset, batch_size=4,
                                            shuffle=True, num_workers=2)

    testset = torchvision.datasets.CIFAR10(root='./data', train=False,
                                        download=True, transform=transform)
    testloader = torch.utils.data.DataLoader(testset, batch_size=4,
                                            shuffle=False, num_workers=2)
    return trainloader, testloader

# unused helper kept from an earlier attempt
def load_new_params(optimizer, params):
    optimizer.param_groups = []

    param_groups = list(params)
    if len(param_groups) == 0:
        raise ValueError("optimizer got an empty parameter list")
    if not isinstance(param_groups[0], dict):
        param_groups = [{'params': param_groups}]
    for param_group in param_groups:
        optimizer.add_param_group(param_group)

def reload_param_groups(opt, params):
    if isinstance(params, torch.Tensor):
        raise TypeError("params argument given to the optimizer should be "
                        "an iterable of Tensors or dicts, but got " +
                        torch.typename(params))
    # replace params
    params = list(params)
    if isinstance(params[0], dict):
        raise ValueError(f'The hacked higher version does not support proper pytorch grouped params yet.')
    opt.param_groups[0]['params'] = params
    # opt.param_groups = []

    # param_groups = list(params)
    # if len(param_groups) == 0:
    #     raise ValueError("optimizer got an empty parameter list")
    # if not isinstance(param_groups[0], dict):
    #     param_groups = [{'params': param_groups}]

    # for param_group in param_groups:
    #     opt.add_param_group(param_group)

class MySGD(Optimizer):

    def __init__(self, params, trainable_opt_params, trainable_opt_state):
        defaults = {'trainable_opt_params':trainable_opt_params, 'trainable_opt_state':trainable_opt_state}
        super().__init__(params, defaults)

class TrainableSGD(DifferentiableOptimizer):

    def _update(self, grouped_grads, **kwargs):
        prev_lr = self.param_groups[0]['trainable_opt_state']['prev_lr']
        eta = self.param_groups[0]['trainable_opt_params']['eta']
        # start differentiable & trainable update
        zipped = zip(self.param_groups, grouped_grads)
        lr = 0.01*eta(prev_lr).view(1)
        for group_idx, (group, grads) in enumerate(zipped):
            for p_idx, (p, g) in enumerate(zip(group['params'], grads)):
                if g is None:
                    continue
                p_new = p - lr*g
                group['params'][p_idx] = p_new
        # fake returns
        self.param_groups[0]['trainable_opt_state']['prev_lr'] = lr
        # update model
        # new_params = self.param_groups[0]['params'] 
        # new_params = self._track_higher_grads_for_new_params(new_params, self._track_higher_grads)
        # self._fmodel.update_params(new_params)

higher.register_optim(MySGD, TrainableSGD)

def main():
    use_cuda = torch.cuda.is_available()
    device = torch.device("cuda" if use_cuda else "cpu")    
    # get dataloaders
    trainloader, testloader = get_cifar10()
    criterion = nn.CrossEntropyLoss()
    # get trainable opt params
    hidden = torch.randn(size=(1,1),requires_grad=True)
    print(f'-> hidden = {hidden}')
    eta = nn.Sequential(OrderedDict([
        ('fc', nn.Linear(1,1,bias=False)),
        ('sigmoid', nn.Sigmoid())
    ]))
    lr = 0.01
    meta_params = []
    meta_params.append( {'params': hidden, 'lr':lr} )
    meta_params.append( {'params': eta.parameters(), 'lr':lr} )
    # get meta optimizer
    #meta_opt = torch.optim.SGD(meta_params)
    meta_opt = torch.optim.Adam(meta_params)
    #
    trainable_opt_params = {'eta':eta, 'hidden':hidden}
    trainable_opt_state = {'prev_lr':hidden}
    #inner_opt = MySGD(eta.parameters(), trainable_opt_params=trainable_opt_params, trainable_opt_state=trainable_opt_state)
    # diffopt = higher.optim.get_diff_optim(
    #     inner_opt,
    #     eta.parameters(), # for this hack it can be anything
    #     fmodel=None, # None
    #     device=device,
    #     override=None, # None default
    #     track_higher_grads=True # True default
    # )
    # do meta-training/ outerloop argmin L^val(theta)
    nb_outer_steps = 2 # note: here this equals the number of meta-train steps (but it need not be, depending on how you loop through the val set)
    for outer_i, (outer_inputs, outer_targets) in enumerate(testloader, 0):
        meta_opt.zero_grad()
        if outer_i >= nb_outer_steps:
            break
        # sample child_model
        child_model = nn.Sequential(OrderedDict([
            ('conv1', nn.Conv2d(in_channels=3,out_channels=2,kernel_size=5,bias=False)),
            ('relu1', nn.ReLU()),
            ('Flatten', Flatten()),
            ('fc', nn.Linear(in_features=28*28*2,out_features=10,bias=False) )
        ]))
        # do inner-training: ~ argmin L^train(psi)
        nb_inner_steps = 3   
        print('==== Inner Loop ====')
        fmodel = higher.patch.monkeypatch(
            child_model, 
            device, 
            copy_initial_weights=True # True default
        )
        inner_opt = MySGD(child_model.parameters(), trainable_opt_params=trainable_opt_params, trainable_opt_state=trainable_opt_state)
        diffopt = higher.optim.get_diff_optim(
            inner_opt,
            child_model.parameters(), # for this hack it can be anything
            fmodel=fmodel, # None
            device=device,
            override=None, # None default
            track_higher_grads=True # True default
        )
        for inner_i, (inner_inputs, inner_targets) in enumerate(trainloader, 0):
            if inner_i >= nb_inner_steps:
                break
            print(f'-> outer_i = {outer_i}')                
            print(f'-> inner_i = {inner_i}')
            print(f'hidden^<{inner_i}> = {hidden}')
            logits = fmodel(inner_inputs)
            inner_loss = criterion(logits, inner_targets)
            print(f'lr^<{inner_i-1}> = {diffopt.param_groups[0]["trainable_opt_state"]["prev_lr"]}')
            #child_model_params = [{'params':child_model.parameters()}]
            child_model_params = child_model.parameters()
            reload_param_groups(diffopt, child_model_params)
            diffopt._fmodel = fmodel
            diffopt.step(inner_loss)
            print(f'lr^<{inner_i}> = {diffopt.param_groups[0]["trainable_opt_state"]["prev_lr"]}')
            print(f'hidden^<{inner_i}> = {hidden}')
        # compute the meta-loss L^val( theta^{T} - eta* Grad L^train(theta^{T}) )
        outer_outputs = fmodel(outer_inputs)
        meta_loss = criterion(outer_outputs, outer_targets) # L^val
        #grad_of_grads = torch.autograd.grad(outputs=meta_loss, inputs=eta.parameters()) # dmeta_loss/dw0
        print('\n---- Outer loop print statements ----')
        print(f'----> outer_i = {outer_i}')
        print(f'-> outer_loss/meta_loss^<{outer_i}>: {meta_loss}')
        #print(f'child_model.fc.weight.grad = {child_model.fc.weight.grad}')
        meta_loss.backward()
        print(f'hidden.grad = {hidden.grad}')
        assert hidden.grad is not None 
        print(f'eta.fc.weight.grad = {eta.fc.weight.grad}')
        print(f'> hidden^<{outer_i-1}> = {hidden}') # before update
        print(f'> eta.fc.weight^<{outer_i-1}> = {eta.fc.weight.T}')
        meta_opt.step() # meta-optimizer step: more or less theta^<t> := theta^<t> - meta_eta * Grad L^val( theta^{T} - eta* Grad L^train(theta^{T}) )
        print(f'>> hidden^<{outer_i}> = {meta_opt.param_groups[0]["params"][0]}') # after update
        print(f'>> eta.fc.weight^<{outer_i-1}> = {eta.fc.weight.T}')
        del meta_loss
        meta_opt.zero_grad()
        print()

if __name__ == "__main__":
    main()
    print('---> Done\a')


Most likely because you re-use some pre-computed states from one iteration to the next?
Not sure why you have to manually unset/set the fmodel and the different elements from higher…

So the issue is that this line of code in higher

self.param_groups = _copy.deepcopy(other.param_groups)

breaks the trainable step size I am trying to build.
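A minimal illustration of why that deepcopy is a problem (sketch): the copy is a brand-new leaf, so gradients computed through it never reach the original tensor:

hidden = torch.randn(1, 1, requires_grad=True)
hidden_copy = copy.deepcopy(hidden)  # new, unrelated leaf
loss = (hidden_copy ** 2).sum()
loss.backward()
print(hidden.grad)       # None: the graph runs through the copy
print(hidden_copy.grad)  # populated instead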

I tried commenting it out before, but my code was still breaking.

After a lot of exploration, it seems the code works only when I re-instantiate/rebuild the inner optimizer + differentiable optimizer before every inner loop (I think…):

import torch
import torch.nn as nn
from torch.optim.optimizer import Optimizer, required

import higher
from higher.optim import DifferentiableOptimizer
from higher.optim import DifferentiableSGD

import torchvision
import torchvision.transforms as transforms

from torchviz import make_dot

import copy

import itertools

import sys

from collections import OrderedDict

from pdb import set_trace as st

#mini class to add a flatten layer to the ordered dictionary
class Flatten(nn.Module):
    def forward(self, input):
        '''
        Note that input.size(0) is usually the batch size.
        So what it does is that given any input with input.size(0) # of batches,
        will flatten to be 1 * nb_elements.
        '''
        batch_size = input.size(0)
        out = input.view(batch_size,-1)
        return out # (batch_size, *size)

def get_cifar10():
    transform = transforms.Compose(
        [transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])

    trainset = torchvision.datasets.CIFAR10(root='./data', train=True,
                                            download=True, transform=transform)
    trainloader = torch.utils.data.DataLoader(trainset, batch_size=4,
                                            shuffle=True, num_workers=2)

    testset = torchvision.datasets.CIFAR10(root='./data', train=False,
                                        download=True, transform=transform)
    testloader = torch.utils.data.DataLoader(testset, batch_size=4,
                                            shuffle=False, num_workers=2)
    return trainloader, testloader

class MySGD(Optimizer):

    def __init__(self, params, trainable_opt_params, trainable_opt_state):
        defaults = {'trainable_opt_params':trainable_opt_params, 'trainable_opt_state':trainable_opt_state}
        super().__init__(params, defaults)

class TrainableSGD(DifferentiableOptimizer):

    def _update(self, grouped_grads, **kwargs):
        prev_lr = self.param_groups[0]['trainable_opt_state']['prev_lr']
        eta = self.param_groups[0]['trainable_opt_params']['eta']
        # start differentiable & trainable update
        zipped = zip(self.param_groups, grouped_grads)
        lr = 0.01*eta(prev_lr).view(1)
        for group_idx, (group, grads) in enumerate(zipped):
            for p_idx, (p, g) in enumerate(zip(group['params'], grads)):
                if g is None:
                    continue
                p_new = p - lr*g
                group['params'][p_idx] = p_new
        # fake returns
        self.param_groups[0]['trainable_opt_state']['prev_lr'] = lr

higher.register_optim(MySGD, TrainableSGD)

def main():
    # get dataloaders
    trainloader, testloader = get_cifar10()
    criterion = nn.CrossEntropyLoss()

    hidden = torch.randn(size=(1,1),requires_grad=True)
    print(f'-> hidden = {hidden}')
    eta = nn.Sequential(OrderedDict([
        ('fc', nn.Linear(1,1,bias=False)),
        ('sigmoid', nn.Sigmoid())
    ]))

    lr = 0.01
    meta_params = []
    meta_params.append( {'params': hidden, 'lr':lr} )
    meta_params.append( {'params': eta.parameters(), 'lr':lr} )
    #meta_opt = torch.optim.SGD(meta_params)
    meta_opt = torch.optim.Adam(meta_params)
    # do meta-training/outer training minimize outerloop: min_{theta} sum_t L^val( theta^{T} - eta* Grad L^train(theta^{T}) ) 
    nb_outer_steps = 5 # note: here this equals the number of meta-train steps (but it need not be, depending on how you loop through the val set)
    for outer_i, (outer_inputs, outer_targets) in enumerate(testloader, 0):
        meta_opt.zero_grad()
        if outer_i >= nb_outer_steps:
            break
        #
        child_model = nn.Sequential(OrderedDict([
            ('conv1', nn.Conv2d(in_channels=3,out_channels=2,kernel_size=5,bias=False)),
            ('relu1', nn.ReLU()),
            ('Flatten', Flatten()),
            ('fc', nn.Linear(in_features=28*28*2,out_features=10,bias=False) )
        ]))
        # do inner-training: ~ argmin L^train(theta)
        nb_inner_steps = 3
        trainable_opt_params = {'eta':eta, 'hidden':hidden}
        trainable_opt_state = {'prev_lr':hidden}
        child_model_params = [{'params':child_model.parameters()}]
        inner_opt = MySGD(child_model_params, trainable_opt_params=trainable_opt_params, trainable_opt_state=trainable_opt_state)
        print('==== Inner Loop ====')
        with higher.innerloop_ctx(child_model, inner_opt, copy_initial_weights=False) as (fmodel, diffopt):
            for inner_i, (inner_inputs, inner_targets) in enumerate(trainloader, 0):
                if inner_i >= nb_inner_steps:
                    break
                print(f'-> outer_i = {outer_i}')                
                print(f'-> inner_i = {inner_i}')
                print(f'hidden^<{inner_i}> = {hidden}')
                logits = fmodel(inner_inputs)
                inner_loss = criterion(logits, inner_targets)
                print(f'lr^<{inner_i-1}> = {diffopt.param_groups[0]["trainable_opt_state"]["prev_lr"]}')
                diffopt.step(inner_loss)
                print(f'lr^<{inner_i}> = {diffopt.param_groups[0]["trainable_opt_state"]["prev_lr"]}')
                print(f'hidden^<{inner_i}> = {hidden}')
            # compute the meta-loss L^val( theta^{T} - eta* Grad L^train(theta^{T}) )
            outer_outputs = fmodel(outer_inputs)
            meta_loss = criterion(outer_outputs, outer_targets) # L^val
            meta_loss.backward()
            #grad_of_grads = torch.autograd.grad(outputs=meta_loss, inputs=eta.parameters()) # dmeta_loss/dw0
            print('\n---- Outer loop print statements ----')
            print(f'----> outer_i = {outer_i}')
            print(f'-> outer_loss/meta_loss^<{outer_i}>: {meta_loss}')
            #print(f'child_model.fc.weight.grad = {child_model.fc.weight.grad}')
            print(f'hidden.grad = {hidden.grad}')
            assert hidden.grad is not None
            assert eta.fc.weight is not None
            print(f'eta.fc.weight.grad = {eta.fc.weight.grad}')
            print(f'> hidden^<{outer_i-1}> = {hidden}') # before update
            print(f'> eta.fc.weight^<{outer_i-1}> = {eta.fc.weight.T}')
            meta_opt.step() # meta-optimizer step: more or less theta^<t> := theta^<t> - meta_eta * Grad L^val( theta^{T} - eta* Grad L^train(theta^{T}) )
            print(f'>> hidden^<{outer_i}> = {meta_opt.param_groups[0]["params"][0]}') # after update
            print(f'>> eta.fc.weight^<{outer_i-1}> = {eta.fc.weight.T}')
            print()

if __name__ == "__main__":
    main()
    print('---> Done\a')

this works now:

import torch
import torch.nn as nn
from torch.optim.optimizer import Optimizer, required

import higher
from higher.optim import DifferentiableOptimizer
from higher.optim import DifferentiableSGD

import torchvision
import torchvision.transforms as transforms

from torchviz import make_dot

import copy

import itertools

import sys

from collections import OrderedDict

from pdb import set_trace as st

#mini class to add a flatten layer to the ordered dictionary
class Flatten(nn.Module):
    def forward(self, input):
        '''
        Note that input.size(0) is usually the batch size.
        So what it does is that given any input with input.size(0) # of batches,
        will flatten to be 1 * nb_elements.
        '''
        batch_size = input.size(0)
        out = input.view(batch_size,-1)
        return out # (batch_size, *size)

def get_cifar10():
    transform = transforms.Compose(
        [transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])

    trainset = torchvision.datasets.CIFAR10(root='./data', train=True,
                                            download=True, transform=transform)
    trainloader = torch.utils.data.DataLoader(trainset, batch_size=4,
                                            shuffle=True, num_workers=2)

    testset = torchvision.datasets.CIFAR10(root='./data', train=False,
                                        download=True, transform=transform)
    testloader = torch.utils.data.DataLoader(testset, batch_size=4,
                                            shuffle=False, num_workers=2)
    return trainloader, testloader

class MySGD(Optimizer):

    def __init__(self, params, trainable_opt_params, trainable_opt_state):
        defaults = {'trainable_opt_params':trainable_opt_params, 'trainable_opt_state':trainable_opt_state}
        super().__init__(params, defaults)

class TrainableSGD(DifferentiableOptimizer):

    def _update(self, grouped_grads, **kwargs):
        prev_lr = self.param_groups[0]['trainable_opt_state']['prev_lr']
        eta = self.param_groups[0]['trainable_opt_params']['eta']
        # start differentiable & trainable update
        zipped = zip(self.param_groups, grouped_grads)
        lr = 0.01*eta(prev_lr).view(1)
        for group_idx, (group, grads) in enumerate(zipped):
            for p_idx, (p, g) in enumerate(zip(group['params'], grads)):
                if g is None:
                    continue
                p_new = p - lr*g
                group['params'][p_idx] = p_new
        # fake returns
        self.param_groups[0]['trainable_opt_state']['prev_lr'] = lr

higher.register_optim(MySGD, TrainableSGD)

def main():
    use_cuda = torch.cuda.is_available() 
    device = torch.device("cuda" if use_cuda else "cpu")
    print(f'device = {device}')
    # get dataloaders
    trainloader, testloader = get_cifar10()
    criterion = nn.CrossEntropyLoss()
    # get inner opt
    hidden = torch.randn(size=(1,1), device=device, requires_grad=True)
    print(f'-> hidden = {hidden}')
    eta = nn.Sequential(OrderedDict([
        ('fc', nn.Linear(1,1, bias=False)),
        ('sigmoid', nn.Sigmoid())
    ])).to(device)
    trainable_opt_params = {'eta':eta, 'hidden':hidden}
    trainable_opt_state = {'prev_lr':hidden}
    # get outer opt
    lr = 0.05
    meta_params = []
    meta_params.append( {'params': hidden, 'lr':lr} )
    meta_params.append( {'params': eta.parameters(), 'lr':lr} )
    #meta_opt = torch.optim.SGD(meta_params)
    meta_opt = torch.optim.Adam(meta_params)
    # do meta-training/outer training minimize outerloop: min_{theta} sum_t L^val( theta^{T} - eta* Grad L^train(theta^{T}) ) 
    nb_outer_steps = 10 # note: here this equals the number of meta-train steps (but it need not be, depending on how you loop through the val set)
    for outer_i, (outer_inputs, outer_targets) in enumerate(testloader, 0):
        outer_inputs, outer_targets = outer_inputs.to(device), outer_targets.to(device)
        meta_opt.zero_grad()
        if outer_i >= nb_outer_steps:
            break
        #
        child_model = nn.Sequential(OrderedDict([
            ('conv1', nn.Conv2d(in_channels=3,out_channels=2,kernel_size=5,bias=False)),
            ('relu1', nn.ReLU()),
            ('Flatten', Flatten()),
            ('fc', nn.Linear(in_features=28*28*2,out_features=10,bias=False) )
        ])).to(device)
        # do inner-training: ~ argmin L^train(theta)
        nb_inner_steps = 2
        trainable_opt_params = {'eta':eta, 'hidden':hidden}
        trainable_opt_state = {'prev_lr':hidden}
        print(trainable_opt_state)
        child_model_params = [{'params':child_model.parameters()}]
        inner_opt = MySGD(child_model_params, trainable_opt_params=trainable_opt_params, trainable_opt_state=trainable_opt_state)
        print('==== Inner Loop ====')
        with higher.innerloop_ctx(child_model, inner_opt, copy_initial_weights=True) as (fmodel, diffopt):
            for inner_i, (inner_inputs, inner_targets) in enumerate(trainloader, 0):
                inner_inputs, inner_targets = inner_inputs.to(device), inner_targets.to(device)
                if inner_i >= nb_inner_steps:
                    break
                print(f'-> outer_i = {outer_i}')                
                print(f'-> inner_i = {inner_i}')
                print(f'hidden^<{inner_i}> = {hidden}')
                logits = fmodel(inner_inputs)
                inner_loss = criterion(logits, inner_targets)
                print(f'lr^<{inner_i-1}> = {diffopt.param_groups[0]["trainable_opt_state"]["prev_lr"]}')
                diffopt.step(inner_loss)
                print(f'lr^<{inner_i}> = {diffopt.param_groups[0]["trainable_opt_state"]["prev_lr"]}')
                print(f'hidden^<{inner_i}> = {hidden}')
            # compute the meta-loss L^val( theta^{T} - eta* Grad L^train(theta^{T}) )
            outer_outputs = fmodel(outer_inputs)
            meta_loss = criterion(outer_outputs, outer_targets) # L^val
            meta_loss.backward()
            #grad_of_grads = torch.autograd.grad(outputs=meta_loss, inputs=eta.parameters()) # dmeta_loss/dw0
            print('\n---- Outer loop print statements ----')
            print(f'----> outer_i = {outer_i}')
            print(f'-> outer_loss/meta_loss^<{outer_i}>: {meta_loss}')
            #print(f'child_model.fc.weight.grad = {child_model.fc.weight.grad}')
            print(f'hidden.grad = {hidden.grad}')
            assert hidden.grad is not None
            assert eta.fc.weight is not None
            print(f'eta.fc.weight.grad = {eta.fc.weight.grad}')
            print(f'> hidden^<{outer_i-1}> = {hidden}') # before update
            print(f'> eta.fc.weight^<{outer_i-1}> = {eta.fc.weight.T}')
            meta_opt.step() # meta-optimizer step: more or less theta^<t> := theta^<t> - meta_eta * Grad L^val( theta^{T} - eta* Grad L^train(theta^{T}) )
            print(f'>> hidden^<{outer_i}> = {meta_opt.param_groups[0]["params"][0]}') # after update
            print(f'>> eta.fc.weight^<{outer_i-1}> = {eta.fc.weight.T}')
            print()

if __name__ == "__main__":
    main()
    print('---> Done\a')

I got lucky. It breaks if you comment out:

        trainable_opt_state = {'prev_lr':hidden}

inside here:

        # do inner-training: ~ argmin L^train(theta)
        nb_inner_steps = 2
        trainable_opt_params = {'eta':eta, 'hidden':hidden}
        trainable_opt_state = {'prev_lr':hidden}
        print(trainable_opt_state)
        child_model_params = [{'params':child_model.parameters()}]
        inner_opt = MySGD(child_model_params, trainable_opt_params=trainable_opt_params, trainable_opt_state=trainable_opt_state)
        print('==== Inner Loop ====')

I believe it’s because the next time I compute meta_loss, it holds a reference to the previous computation graph, since prev_lr belongs to the previous iteration of the outer loop.
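For reference, a minimal reproduction of that error (sketch): reusing a tensor whose upstream graph was freed by an earlier backward:

a = torch.randn(1, requires_grad=True)
b = a * 2              # b's history lives in the first graph
(b ** 2).backward()    # backward frees that graph
try:
    (b ** 2).backward()  # reuses b's freed history
except RuntimeError as err:
    print(err)  # "Trying to backward through the graph a second time..."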
