How does one have the parameters of a model NOT be leaf tensors?
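
(By "not leaf tensors" I mean something like this minimal sketch, using a toy nn.Linear that has nothing to do with the model below: after deleting the registered Parameter, the module's weight can be an ordinary tensor that is the output of a differentiable op, i.e. a non-leaf.)

import torch
import torch.nn as nn

lin = nn.Linear(2, 2, bias=False)
w0 = lin.weight                          # the registered Parameter (a leaf)
new_w = w0 - 0.1 * torch.ones_like(w0)   # result of an op -> non-leaf tensor
del lin.weight                           # remove the registered Parameter
lin.weight = new_w                       # plain attribute holding the non-leaf tensor
out = lin(torch.randn(1, 2)).sum()
out.backward()                           # gradients flow through new_w back into w0
print(w0.grad)                           # w0 is still the leaf, so it receives the grad

As far as I understand, this is essentially what higher's monkeypatched models do for every parameter.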

@albanD my code is nearly working, but the one remaining issue is that PyTorch thinks I'm doing a backward pass twice through the same graph. I'm not sure why it thinks that, because I delete the output node as you suggested (here: How to free graph manually?).

Do you know why it might not be working?
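
For reference, the error itself is the usual one you get when a second backward goes through a part of the graph whose saved buffers were already freed. A toy sketch of that situation (again, nothing to do with my actual model):

import torch

w = torch.randn(1, requires_grad=True)
x = torch.randn(1)

h = w * x                  # intermediate node; its backward saves w and x
loss1 = (h ** 2).sum()
loss1.backward()           # frees the saved tensors of the graph through h

loss2 = (h ** 2).sum()     # reuses h, so this new graph shares the freed part
# loss2.backward()         # -> RuntimeError: Trying to backward through the graph a second time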

Code:

import torch
import torch.nn as nn
from torch.optim.optimizer import Optimizer, required

import higher
from higher.optim import DifferentiableOptimizer
from higher.optim import DifferentiableSGD

import torchvision
import torchvision.transforms as transforms

from torchviz import make_dot

import copy

import itertools

import sys

from collections import OrderedDict

from pdb import set_trace as st

# mini class to add a flatten layer to the OrderedDict-based nn.Sequential
class Flatten(nn.Module):
    def forward(self, input):
        '''
        Note that input.size(0) is usually the batch size.
        Given any input with batch size input.size(0), this keeps the batch
        dimension and flattens each sample into a single vector of nb_elements.
        '''
        batch_size = input.size(0)
        out = input.view(batch_size,-1)
        return out # (batch_size, *size)

def get_cifar10():
    transform = transforms.Compose(
        [transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])

    trainset = torchvision.datasets.CIFAR10(root='./data', train=True,
                                            download=True, transform=transform)
    trainloader = torch.utils.data.DataLoader(trainset, batch_size=4,
                                            shuffle=True, num_workers=2)

    testset = torchvision.datasets.CIFAR10(root='./data', train=False,
                                        download=True, transform=transform)
    testloader = torch.utils.data.DataLoader(testset, batch_size=4,
                                            shuffle=False, num_workers=2)
    return trainloader, testloader

def load_new_params(optimizer, params):
    optimizer.param_groups = []

    param_groups = list(params)
    if len(param_groups) == 0:
        raise ValueError("optimizer got an empty parameter list")
    if not isinstance(param_groups[0], dict):
        param_groups = [{'params': param_groups}]
    for param_group in param_groups:
        optimizer.add_param_group(param_group)

def reload_param_groups(opt, params):
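    # replace the optimizer's single param group with the (possibly non-leaf)
    # parameters of the current child model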
    if isinstance(params, torch.Tensor):
        raise TypeError("params argument given to the optimizer should be "
                        "an iterable of Tensors or dicts, but got " +
                        torch.typename(params))
    # replace params
    params = list(params)
    if isinstance(params[0], dict):
        raise ValueError('The hacked higher version does not support proper pytorch grouped params yet.')
    opt.param_groups[0]['params'] = params
    # opt.param_groups = []

    # param_groups = list(params)
    # if len(param_groups) == 0:
    #     raise ValueError("optimizer got an empty parameter list")
    # if not isinstance(param_groups[0], dict):
    #     param_groups = [{'params': param_groups}]

    # for param_group in param_groups:
    #     opt.add_param_group(param_group)

class MySGD(Optimizer):
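    # minimal Optimizer subclass: `defaults` just carries the trainable lr-network
    # (eta) and its state so the differentiable TrainableSGD below can read them
    # from param_groups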

    def __init__(self, params, trainable_opt_params, trainable_opt_state):
        defaults = {'trainable_opt_params':trainable_opt_params, 'trainable_opt_state':trainable_opt_state}
        super().__init__(params, defaults)

class TrainableSGD(DifferentiableOptimizer):

    def _update(self, grouped_grads, **kwargs):
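        # differentiable update: eta maps the previous lr to a new one, and the
        # parameter update p - lr*g stays in the autograd graph so the meta-loss
        # can later backprop into eta and hidden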
        prev_lr = self.param_groups[0]['trainable_opt_state']['prev_lr']
        eta = self.param_groups[0]['trainable_opt_params']['eta']
        # start differentiable & trainable update
        zipped = zip(self.param_groups, grouped_grads)
        lr = 0.01*eta(prev_lr).view(1)
        for group_idx, (group, grads) in enumerate(zipped):
            for p_idx, (p, g) in enumerate(zip(group['params'], grads)):
                if g is None:
                    continue
                p_new = p - lr*g
                group['params'][p_idx] = p_new
        # fake returns: store the new (differentiable) lr as state for the next step
        self.param_groups[0]['trainable_opt_state']['prev_lr'] = lr
        # update model
        # new_params = self.param_groups[0]['params'] 
        # new_params = self._track_higher_grads_for_new_params(new_params, self._track_higher_grads)
        # self._fmodel.update_params(new_params)

higher.register_optim(MySGD, TrainableSGD)  # tell higher to pair MySGD with the differentiable TrainableSGD

def main():
    use_cuda = torch.cuda.is_available()
    device = torch.device("cuda" if use_cuda else "cpu")    
    # get dataloaders
    trainloader, testloader = get_cifar10()
    criterion = nn.CrossEntropyLoss()
    # get trainable opt params
    hidden = torch.randn(size=(1,1),requires_grad=True)
    print(f'-> hidden = {hidden}')
    eta = nn.Sequential(OrderedDict([
        ('fc', nn.Linear(1,1,bias=False)),
        ('sigmoid', nn.Sigmoid())
    ]))
    lr = 0.01
    meta_params = []
    meta_params.append( {'params': hidden, 'lr':lr} )
    meta_params.append( {'params': eta.parameters(), 'lr':lr} )
    # get meta optimizer
    #meta_opt = torch.optim.SGD(meta_params)
    meta_opt = torch.optim.Adam(meta_params)
    #
    trainable_opt_params = {'eta':eta, 'hidden':hidden}
    trainable_opt_state = {'prev_lr':hidden}
    #inner_opt = MySGD(eta.parameters(), trainable_opt_params=trainable_opt_params, trainable_opt_state=trainable_opt_state)
    # diffopt = higher.optim.get_diff_optim(
    #     inner_opt,
    #     eta.parameters(), # for this hack it can be anything
    #     fmodel=None, # None
    #     device=device,
    #     override=None, # None default
    #     track_higher_grads=True # True default
    # )
    # do meta-training/ outerloop argmin L^val(theta)
    nb_outer_steps = 2 # note: in this case it equals the number of meta-train steps (but it need not be, depending on how you loop through the val set)
    for outer_i, (outer_inputs, outer_targets) in enumerate(testloader, 0):
        meta_opt.zero_grad()
        if outer_i >= nb_outer_steps:
            break
        # sample child_model
        child_model = nn.Sequential(OrderedDict([
            ('conv1', nn.Conv2d(in_channels=3,out_channels=2,kernel_size=5,bias=False)),
            ('relu1', nn.ReLU()),
            ('Flatten', Flatten()),
            ('fc', nn.Linear(in_features=28*28*2,out_features=10,bias=False) )
        ]))
        # do inner-training: ~ argmin L^train(psi)
        nb_inner_steps = 3   
        print('==== Inner Loop ====')
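        # build a fresh functional (monkeypatched) copy of the child model and a
        # differentiable optimizer for this outer step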
        fmodel = higher.patch.monkeypatch(
            child_model, 
            device, 
            copy_initial_weights=True # True default
        )
        inner_opt = MySGD(child_model.parameters(), trainable_opt_params=trainable_opt_params, trainable_opt_state=trainable_opt_state)
        diffopt = higher.optim.get_diff_optim(
            inner_opt,
            child_model.parameters(), # for this hack it can be anything
            fmodel=fmodel, # None default
            device=device,
            override=None, # None default
            track_higher_grads=True # True default
        )
        for inner_i, (inner_inputs, inner_targets) in enumerate(trainloader, 0):
            if inner_i >= nb_inner_steps:
                break
            print(f'-> outer_i = {outer_i}')                
            print(f'-> inner_i = {inner_i}')
            print(f'hidden^<{inner_i}> = {hidden}')
            logits = fmodel(inner_inputs)
            inner_loss = criterion(logits, inner_targets)
            print(f'lr^<{inner_i-1}> = {diffopt.param_groups[0]["trainable_opt_state"]["prev_lr"]}')
            #child_model_params = [{'params':child_model.parameters()}]
            child_model_params = child_model.parameters()
            reload_param_groups(diffopt, child_model_params)
            diffopt._fmodel = fmodel
            diffopt.step(inner_loss)
            print(f'lr^<{inner_i}> = {diffopt.param_groups[0]["trainable_opt_state"]["prev_lr"]}')
            print(f'hidden^<{inner_i}> = {hidden}')
        # compute the meta-loss L^val( theta^{T} - eta* Grad L^train(theta^{T}) )
        outer_outputs = fmodel(outer_inputs)
        meta_loss = criterion(outer_outputs, outer_targets) # L^val
        #grad_of_grads = torch.autograd.grad(outputs=meta_loss, inputs=eta.parameters()) # dmeta_loss/dw0
        print('\n---- Outer loop print statements ----')
        print(f'----> outer_i = {outer_i}')
        print(f'-> outer_loss/meta_loss^<{outer_i}>: {meta_loss}')
        #print(f'child_model.fc.weight.grad = {child_model.fc.weight.grad}')
        meta_loss.backward()
        print(f'hidden.grad = {hidden.grad}')
        assert hidden.grad is not None 
        print(f'eta.fc.weight.grad = {eta.fc.weight.grad}')
        print(f'> hidden^<{outer_i-1}> = {hidden}') # before update
        print(f'> eta.fc.weight^<{outer_i-1}> = {eta.fc.weight.T}')
        meta_opt.step() # meta-optimizer step: more or less theta^<t> := theta^<t> - meta_eta * Grad L^val( theta^{T} - eta* Grad L^train(theta^{T}) )
        print(f'>> hidden^<{outer_i}> = {meta_opt.param_groups[0]["params"][0]}') # after update
        print(f'>> eta.fc.weight^<{outer_i}> = {eta.fc.weight.T}')
        del meta_loss
        meta_opt.zero_grad()
        print()

if __name__ == "__main__":
    main()
    print('---> Done\a')