Higher Order Derivatives - Meta Learning

I am trying to write code for some Meta-Learning algorithms. I understand that there are a few packages available for easy and hassle-free implementation of Meta-Learning algorithms (higher, pytorch-meta), but I want to understand a few things conceptually.

Recently, a few meta-learning implementations such as Learning to Reweight, Meta-Weight-Net, etc. have not been using higher or pytorch-meta and have instead been using a custom nn.Module subclass (see code below) written by Daniel (link: https://github.com/danieltan07/learning-to-reweight-examples/blob/master/meta_layers.py). Basically, it is the usual PyTorch code for supervised learning, the only change being that this custom module is used in place of the standard nn.Module.

I am pasting the relevant code below. The module shown below (Daniel’s code) is what people have been using for their meta-learning algorithms, thereby not requiring the additional packages mentioned above.

class MetaModule(nn.Module):
    # adopted from: Adrien Ecoffet https://github.com/AdrienLE
    def params(self):
        for name, param in self.named_params(self):
            yield param
    
    def named_leaves(self):
        return []
    
    def named_submodules(self):
        return []
    
    def named_params(self, curr_module=None, memo=None, prefix=''):       
        if memo is None:
            memo = set()

        if hasattr(curr_module, 'named_leaves'):
            for name, p in curr_module.named_leaves():
                if p is not None and p not in memo:
                    memo.add(p)
                    yield prefix + ('.' if prefix else '') + name, p
        else:
            for name, p in curr_module._parameters.items():
                if p is not None and p not in memo:
                    memo.add(p)
                    yield prefix + ('.' if prefix else '') + name, p
                    
        for mname, module in curr_module.named_children():
            submodule_prefix = prefix + ('.' if prefix else '') + mname
            for name, p in self.named_params(module, memo, submodule_prefix):
                yield name, p
    
    def update_params(self, lr_inner, first_order=False, source_params=None, detach=False):
        if source_params is not None:
            for tgt, src in zip(self.named_params(self), source_params):
                name_t, param_t = tgt
                # name_s, param_s = src
                # grad = param_s.grad
                grad = src
                if first_order:
                    grad = to_var(grad.detach().data)
                # the subtraction is a regular autograd op, so the new value
                # keeps its history w.r.t. the old params and the grads
                tmp = param_t - lr_inner * grad
                self.set_param(self, name_t, tmp)
        else:

            for name, param in self.named_params(self):
                if not detach:
                    grad = param.grad
                    if first_order:
                        grad = to_var(grad.detach().data)
                    tmp = param - lr_inner * grad
                    self.set_param(self, name, tmp)
                else:
                    param = param.detach_()
                    self.set_param(self, name, param)

    def set_param(self, curr_mod, name, param):
        if '.' in name:
            # recurse into the sub-module that owns the (possibly nested) name
            module_name, rest = name.split('.', 1)
            for child_name, mod in curr_mod.named_children():
                if module_name == child_name:
                    self.set_param(mod, rest, param)
                    break
        else:
            setattr(curr_mod, name, param)
            
    def detach_params(self):
        for name, param in self.named_params(self):
            self.set_param(self, name, param.detach())   
                
    def copy(self, other, same_var=False):
        for name, param in other.named_params(other):
            if not same_var:
                param = to_var(param.data.clone(), requires_grad=True)
            self.set_param(self, name, param)

Using such a MetaModule, one can create MetaLinear, MetaConv2d, etc., which can be used instead of nn.Linear, nn.Conv2d:

class MetaLinear(MetaModule):
    def __init__(self, *args, **kwargs):
        super().__init__()
        ignore = nn.Linear(*args, **kwargs)
       
        self.register_buffer('weight', to_var(ignore.weight.data, requires_grad=True))
        self.register_buffer('bias', to_var(ignore.bias.data, requires_grad=True))
        
    def forward(self, x):
        return F.linear(x, self.weight, self.bias)
    
    def named_leaves(self):
        return [('weight', self.weight), ('bias', self.bias)]
    
class MetaConv2d(MetaModule):
    def __init__(self, *args, **kwargs):
        super().__init__()
        ignore = nn.Conv2d(*args, **kwargs)
        
        self.stride = ignore.stride
        self.padding = ignore.padding
        self.dilation = ignore.dilation
        self.groups = ignore.groups
        
        self.register_buffer('weight', to_var(ignore.weight.data, requires_grad=True))
        
        if ignore.bias is not None:
            self.register_buffer('bias', to_var(ignore.bias.data, requires_grad=True))
        else:
            self.register_buffer('bias', None)
        
    def forward(self, x):
        return F.conv2d(x, self.weight, self.bias, self.stride, self.padding, self.dilation, self.groups)
    
    def named_leaves(self):
        return [('weight', self.weight), ('bias', self.bias)]

I have the following questions:

  • Why does one need to create a custom nn.Module and then do setattr(name, params) to update the parameters (as is done in the update_params function of the MetaModule class) in such a way that these operations are also recorded in the computation graph? Why can’t I directly use setattr(name, params) in my regular training loop (i.e. in a def train(*args, **kwargs) function) with the standard, built-in nn.Module? (See the sketch after this list.)

  • I do understand that there’s another way to deal with this (link: [resolved] Implementing MAML in PyTorch). For instance, the def forward(self, x) function of an nn.Module can be changed to def forward(self, x, weights) so that the code works. But I am not fully clear about this either. I understand that nn.Parameters don’t record history and hence we need to operate on other Tensors and then copy those values back into the nn.Parameters, but I wonder whether I can do what I want without having to resort to this technique.
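
For concreteness regarding the first question, here is a minimal sketch (my own toy example, not code from any of the repos above) of what happens when the naive setattr is tried on a standard nn.Linear:

import torch
import torch.nn as nn

lin = nn.Linear(3, 1)
x = torch.randn(4, 3)
loss = lin(x).sum()
grads = torch.autograd.grad(loss, tuple(lin.parameters()), create_graph=True)

# An inner-loop update produces a non-leaf tensor that carries history:
new_weight = lin.weight - 0.1 * grads[0]

# nn.Module.__setattr__ only accepts an nn.Parameter (or None) for an attribute
# that is already registered as a parameter, so the assignment below raises a
# TypeError ("cannot assign 'torch.FloatTensor' as parameter 'weight' ..."); the
# exact message may differ across PyTorch versions. Wrapping new_weight in
# nn.Parameter would not help either, since a Parameter is always a leaf and
# the inner-loop history would be cut.
try:
    lin.weight = new_weight
except TypeError as e:
    print(e)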

Note: My question is somewhat similar to Second order derivatives in meta-learning. However, what I am asking is conceptual and not necessarily a request for a workaround. And the only response in that thread is mine, so I am still not clear about implementing meta-learning algorithms in PyTorch.


Hi,

  • I think the set_param function here is mainly built to handle nested names. For example, if your module is an nn.Sequential that contains a conv, then the parameter name will be 0.weight. You cannot use Python’s setattr directly with that; you first need to access “0” and then “weight” (see the sketch after this list).
  • I don’t think you can get around some logic like that. The main reason is that you don’t want to overwrite the original Parameters, because you need to be able to backpropagate all the way back to them in order to update them. So the intermediary Tensors can’t just be these Parameters modified in place.
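
To make the nested-name point concrete, here is a small sketch (a toy nn.Sequential, not code from any of the repos above) of what a plain setattr does with a dotted name versus what set_param’s recursion effectively does:

import torch
import torch.nn as nn

seq = nn.Sequential(nn.Conv2d(1, 4, 3), nn.ReLU())
name = "0.weight"                        # nested name from named_parameters()
new_w = torch.zeros_like(seq[0].weight)

# Naive setattr does NOT reach the sub-module: it just attaches an attribute
# literally called "0.weight" to `seq` itself; seq[0].weight is untouched.
setattr(seq, name, new_w)
print(getattr(seq, "0.weight") is new_w)   # True
print(seq[0].weight is new_w)              # False

# set_param's recursion amounts to walking the dotted path first:
mod, attr = seq, name
while "." in attr:
    head, attr = attr.split(".", 1)
    mod = getattr(mod, head)               # e.g. seq -> seq[0]
# Here `mod` is the Conv2d and `attr` is "weight". On a MetaModule,
# setattr(mod, attr, new_w) works because "weight" is a plain tensor/buffer;
# on a stock nn.Conv2d it would raise a TypeError since "weight" is a Parameter.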

Hope this helps.

Thanks @albanD for responding so quickly. I understand your points and the part about handling nested names. But I want to understand why it is a problem when I use setattr() for each of those weights/biases separately myself (i.e. with a standard nn.Module, not a MetaModule) - something like this:

new_named_params = ...
for xxx in modules_of_network:
    for name, p in xxx.named_parameters():
        setattr(xxx, name, new_named_params[name])

I think I understand the problems with this approach to some degree, but I am not clear why using setattr() in a custom nn.Module - MetaModule - doesn’t throw the same kind of errors, even though the parameters are updated in place there as well.

setattr is actually the same as doing xxx.name = new_named_params[name].
So if it is already a nn.Parameter, you won’t be able to set a Tensor with history there.

I understand. But that still doesn’t explain how MetaModule enables Meta-Learning without using packages such as higher and others.

Hi,

It handles the params in a different way compared to the regular nn.Module. In particular, it allows the parameters to have some history associated with them by not making them nn.Parameters.

I think what you are trying to say is that, because nn.Parameters can’t “record” history, the example I talked about above uses register_buffer instead of nn.Parameter as a neat hack.
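
To convince myself, here is a tiny sketch (the class name TinyMetaLinear is made up, not from Daniel’s code) showing that a buffer, unlike an nn.Parameter attribute, can be re-assigned with a tensor that keeps its autograd history:

import torch
import torch.nn as nn
import torch.nn.functional as F

class TinyMetaLinear(nn.Module):
    def __init__(self):
        super().__init__()
        # the "parameter" is deliberately a buffer, not an nn.Parameter
        self.register_buffer('weight', torch.randn(1, 3, requires_grad=True))

    def forward(self, x):
        return F.linear(x, self.weight)

m = TinyMetaLinear()
x = torch.randn(5, 3)
loss = m(x).pow(2).mean()
(g,) = torch.autograd.grad(loss, m.weight, create_graph=True)

# Re-assigning the buffer with a non-leaf tensor is allowed (an nn.Parameter
# attribute would reject it), so the updated weight remembers how it was
# computed from the old weight and the gradient:
m.weight = m.weight - 0.1 * g
print(m.weight.grad_fn)   # e.g. <SubBackward0 ...>, i.e. history is kept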

I think it’s starting to make sense now. Here’s what I think is going on (code: Daniel’s code for ‘Learning To Reweight’ algorithm):

def train_lre():
    net, opt = build_model() # uses MetaModule to create the model instead of nn.Module
    
    meta_losses_clean = []
    net_losses = []
    plot_step = 100

    smoothing_alpha = 0.9
    
    meta_l = 0
    net_l = 0
    accuracy_log = []
    for i in tqdm(range(hyperparameters['num_iterations'])):
        net.train()
        # Line 2 get batch of data
        image, labels = next(iter(data_loader))
        # since validation data is small I just fixed them instead of building an iterator
        # initialize a dummy network for the meta learning of the weights
        meta_net = LeNet(n_out=1)
        meta_net.load_state_dict(net.state_dict())

        if torch.cuda.is_available():
            meta_net.cuda()

        image = to_var(image, requires_grad=False)
        labels = to_var(labels, requires_grad=False)

        # Lines 4 - 5 initial forward pass to compute the initial weighted loss
        y_f_hat  = meta_net(image)
        cost = F.binary_cross_entropy_with_logits(y_f_hat,labels, reduce=False)
        eps = to_var(torch.zeros(cost.size()))
        l_f_meta = torch.sum(cost * eps)

        meta_net.zero_grad()
        
        # Line 6 perform a parameter update
        grads = torch.autograd.grad(l_f_meta, (meta_net.params()), create_graph=True)
        meta_net.update_params(hyperparameters['lr'], source_params=grads)
        
        # Line 8 - 10 2nd forward pass and getting the gradients with respect to epsilon
        y_g_hat = meta_net(val_data)

        l_g_meta = F.binary_cross_entropy_with_logits(y_g_hat,val_labels)

        grad_eps = torch.autograd.grad(l_g_meta, eps, only_inputs=True)[0]
        
        # Line 11 computing and normalizing the weights
        w_tilde = torch.clamp(-grad_eps,min=0)
        norm_c = torch.sum(w_tilde)

        if norm_c != 0:
            w = w_tilde / norm_c
        else:
            w = w_tilde

        # Lines 12 - 14 computing for the loss with the computed weights
        # and then perform a parameter update
        y_f_hat = net(image)
        cost = F.binary_cross_entropy_with_logits(y_f_hat, labels, reduce=False)
        l_f = torch.sum(cost * w)

        opt.zero_grad()
        l_f.backward()
        opt.step()

    # (the accuracy-tracking / plotting code from the original script is
    # omitted in this excerpt; acc_log is defined there)
    return np.mean(acc_log[-6:-1, 1])

The trick here is to use the buffers (created via register_buffer in MetaLinear, etc.) in place of nn.Parameters: the trainable tensors (weight and bias) can carry history, and they are manipulated via the named_leaves, named_params, update_params, and set_param functions to do the meta-learning. The real network’s nn.Parameters are updated via opt.step(), whereas all the intermediate parameter updates required for meta-learning are handled via the buffer tensors (viz. weight and bias created via register_buffer).
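
To double-check this understanding, here is a condensed, self-contained sketch of that mechanism (toy class and variable names, not Daniel’s actual classes): the inner update re-assigns the buffers with tensors that keep their history, so the gradient of the validation loss can flow all the way back to eps:

import torch
import torch.nn as nn
import torch.nn.functional as F

class ToyMetaLinear(nn.Module):
    def __init__(self, d_in, d_out):
        super().__init__()
        lin = nn.Linear(d_in, d_out)
        # trainable tensors live in buffers, not nn.Parameters
        self.register_buffer('weight', lin.weight.data.clone().requires_grad_(True))
        self.register_buffer('bias', lin.bias.data.clone().requires_grad_(True))

    def forward(self, x):
        return F.linear(x, self.weight, self.bias)

torch.manual_seed(0)
x_train, y_train = torch.randn(8, 3), torch.randint(0, 2, (8, 1)).float()
x_val, y_val = torch.randn(4, 3), torch.randint(0, 2, (4, 1)).float()

meta_net = ToyMetaLinear(3, 1)
lr_inner = 0.1

# inner step: per-example weights eps, weighted loss, grads with create_graph=True
cost = F.binary_cross_entropy_with_logits(meta_net(x_train), y_train, reduction='none')
eps = torch.zeros_like(cost, requires_grad=True)
l_f_meta = torch.sum(cost * eps)
grads = torch.autograd.grad(l_f_meta, (meta_net.weight, meta_net.bias), create_graph=True)

# the analogue of update_params/set_param: re-assign the buffers with tensors
# that keep their history (this would be rejected for nn.Parameter attributes)
meta_net.weight = meta_net.weight - lr_inner * grads[0]
meta_net.bias = meta_net.bias - lr_inner * grads[1]

# outer step: validation loss, then the gradient w.r.t. eps flows back through
# the inner update (this is the second-order part)
l_g_meta = F.binary_cross_entropy_with_logits(meta_net(x_val), y_val)
grad_eps = torch.autograd.grad(l_g_meta, eps)[0]
print(grad_eps.shape)   # torch.Size([8, 1]) -- one reweighting signal per example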

@albanD Sorry for bothering you so much, but I think that does explain it, don’t you think?


Hi,

Yes I think it does.
Note that having the intermediary results as buffers vs. just regular attributes doesn’t change much.
It will only change when you get the state dict or move the module to a different device. But hopefully you should not be doing that in the middle of the forward pass :wink:

I didn’t get you, @albanD. Can you elaborate a bit?

For any Tensor in the nn.Module, you can store it in self by doing self.foo = your_tensor or by doing self.register_buffer("foo", your_tensor).
I think that both will have the behavior that you want: you can access them by using self.foo and they can have history.

I think the first one might be simpler to read, as it is basic Python semantics.
I can’t think of any reason why you would need it to actually be a buffer (maybe I’m missing something, though).
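
For reference, a quick sketch (toy module names A and B, just an illustration) of where the buffer vs. plain-attribute choice actually shows up:

import torch
import torch.nn as nn

class A(nn.Module):
    def __init__(self):
        super().__init__()
        self.foo = torch.randn(3, requires_grad=True)                    # plain attribute

class B(nn.Module):
    def __init__(self):
        super().__init__()
        self.register_buffer('foo', torch.randn(3, requires_grad=True))  # buffer

a, b = A(), B()
a.foo = a.foo * 2.0   # both accept tensors that carry history
b.foo = b.foo * 2.0
print(a.foo.grad_fn, b.foo.grad_fn)                        # both have a MulBackward0 node
print('foo' in a.state_dict(), 'foo' in b.state_dict())    # False, True: only the buffer
                                                           # shows up in the state dict
                                                           # (and moves with .to(device))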