Higher Order Derivatives - Meta Learning

I think what you are saying is that, because nn.Parameters can’t “record” history through an update, the example I mentioned above registers the weights with register_buffer instead of nn.Parameter as a neat hack.
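
To make the “record history” part concrete, here is a minimal sketch with toy tensors I made up (not part of Daniel’s code), showing the difference between an in-place nn.Parameter update and a functional update that stays in the autograd graph:

import torch

x = torch.randn(8, 3)
y = torch.randn(8, 1)
lr = 0.1

# A plain tensor "weight": doing one SGD step out of place with create_graph=True
# keeps the step itself inside the autograd graph...
w = torch.randn(3, 1, requires_grad=True)
loss = ((x @ w - y) ** 2).mean()
grad_w, = torch.autograd.grad(loss, w, create_graph=True)
w_new = w - lr * grad_w                             # new tensor, still connected to w
val_loss = ((x @ w_new - y) ** 2).mean()
second_grad, = torch.autograd.grad(val_loss, w)     # ...so a second backward pass through the update works

# An nn.Parameter updated by an optimizer is modified in place (roughly p.data -= lr * p.grad),
# which records no history: nothing connects the new value to the old one, so you cannot
# backpropagate through that update afterwards.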

I think it’s starting to make sense now. Here’s what I think is going on (the code below is Daniel’s implementation of the ‘Learning to Reweight’ algorithm):

def train_lre():
    net, opt = build_model() # uses MetaModule to create the model instead of nn.Module
    
    meta_losses_clean = []
    net_losses = []
    plot_step = 100

    smoothing_alpha = 0.9
    
    meta_l = 0
    net_l = 0
    accuracy_log = []
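    # (the loss-smoothing / plotting code that uses these variables was omitted from this excerpt)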
    for i in tqdm(range(hyperparameters['num_iterations'])):
        net.train()
        # Line 2 get batch of data
        image, labels = next(iter(data_loader))
        # since validation data is small I just fixed them instead of building an iterator
        # initialize a dummy network for the meta learning of the weights
        meta_net = LeNet(n_out=1)
        meta_net.load_state_dict(net.state_dict())
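        # meta_net is a copy of net whose weights live in buffers (MetaModule), so the
        # inner-loop update below can modify them without touching net's real nn.Parameters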

        if torch.cuda.is_available():
            meta_net.cuda()

        image = to_var(image, requires_grad=False)
        labels = to_var(labels, requires_grad=False)

        # Lines 4 - 5 initial forward pass to compute the initial weighted loss
        y_f_hat  = meta_net(image)
        cost = F.binary_cross_entropy_with_logits(y_f_hat,labels, reduce=False)
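        # eps: one zero-initialized weight per training example; it must require grad
        # (to_var's default) so that l_g_meta below can be differentiated w.r.t. it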
        eps = to_var(torch.zeros(cost.size()))
        l_f_meta = torch.sum(cost * eps)

        meta_net.zero_grad()
        
        # Line 6 perform a parameter update
        grads = torch.autograd.grad(l_f_meta, (meta_net.params()), create_graph=True)
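        # create_graph=True keeps the gradient computation itself in the graph, so the
        # parameter update below remains differentiable w.r.t. eps (the higher-order part)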
        meta_net.update_params(hyperparameters['lr'], source_params=grads)
        
        # Line 8 - 10 2nd forward pass and getting the gradients with respect to epsilon
        y_g_hat = meta_net(val_data)

        l_g_meta = F.binary_cross_entropy_with_logits(y_g_hat,val_labels)

        grad_eps = torch.autograd.grad(l_g_meta, eps, only_inputs=True)[0]
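        # grad_eps tells us how the validation loss changes as each training example's weight is increased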
        
        # Line 11 computing and normalizing the weights
        w_tilde = torch.clamp(-grad_eps,min=0)
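        # keep only the examples whose up-weighting would decrease the validation loss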
        norm_c = torch.sum(w_tilde)

        if norm_c != 0:
            w = w_tilde / norm_c
        else:
            w = w_tilde

        # Lines 12 - 14 computing for the loss with the computed weights
        # and then perform a parameter update
        y_f_hat = net(image)
        cost = F.binary_cross_entropy_with_logits(y_f_hat, labels, reduce=False)
        l_f = torch.sum(cost * w)
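        # w was computed without create_graph, so it acts as a constant here and the
        # backward pass below only updates net's real nn.Parameters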

        opt.zero_grad()
        l_f.backward()
        opt.step()

    # acc_log is built from accuracy_log by the periodic evaluation code, which was omitted from this excerpt
    return np.mean(acc_log[-6:-1, 1])

The trick here is that the Meta* layers (MetaLinear, etc.) register their weight and bias as buffers via register_buffer rather than as nn.Parameters. Buffers are plain tensors, so update_params can replace them with new tensors computed from the old values and the gradients, and those new tensors keep their autograd history, which is exactly what an in-place update of an nn.Parameter cannot do. The named_leaves, named_params, update_params, and set_param helpers all operate on these buffers to perform the intermediate “parameter” updates needed for the meta-learning, while the real nn.Parameters of net are only ever updated by opt.step().
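
For completeness, here is a stripped-down sketch of what I understand MetaLinear / update_params to be doing. It is simplified and written from memory, so the names and details won’t match Daniel’s MetaModule code exactly, but it shows the buffer trick in isolation:

import torch
import torch.nn as nn
import torch.nn.functional as F

class MetaLinear(nn.Module):
    # like nn.Linear, but weight/bias are buffers so they can later be replaced by
    # graph-connected tensors (assigning a plain tensor to an nn.Parameter attribute
    # would raise a TypeError)
    def __init__(self, in_features, out_features):
        super().__init__()
        base = nn.Linear(in_features, out_features)
        self.register_buffer('weight', base.weight.data.clone().requires_grad_(True))
        self.register_buffer('bias', base.bias.data.clone().requires_grad_(True))

    def forward(self, x):
        return F.linear(x, self.weight, self.bias)

    def named_leaves(self):
        return [('weight', self.weight), ('bias', self.bias)]

    def update_params(self, lr, source_params):
        # functional SGD step: replace each buffer with "old - lr * grad"; the new tensors
        # keep the graph, so a later loss computed from them can still be differentiated
        # w.r.t. whatever produced the grads (e.g. eps)
        for (name, param), grad in zip(self.named_leaves(), source_params):
            setattr(self, name, param - lr * grad)

# inner-loop usage, mirroring the inner-loop steps (Lines 4-6 in the comments above), with toy data:
layer = MetaLinear(3, 1)
x = torch.randn(4, 3)
eps = torch.zeros(4, 1, requires_grad=True)
loss = torch.sum((layer(x) ** 2) * eps)
grads = torch.autograd.grad(loss, [p for _, p in layer.named_leaves()], create_graph=True)
layer.update_params(0.1, source_params=grads)
# layer.weight and layer.bias are now non-leaf tensors whose history includes eps,
# so a loss computed from them can be differentiated w.r.t. eps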

@albanD Sorry for bothering you so much, but I think that does explain it, don’t you think?
