Meta-learning: the outer gradient is None

I’m currently working on a meta-learning problem. First, I use the (weighted) training loss to update the parameters of fmodel, and then I use the updated fmodel to compute the second-order gradient with respect to meta_net. However, during this update I found that the meta loss (label_meta in the code below) has no gradient with respect to meta_net: all of its parameter gradients come out as None. Could anyone please help me find where the problem lies?
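
To make the intended gradient flow concrete, here is a minimal sketch of the two-step (inner/outer) update I have in mind. The toy modules, data, and names (toy_model, toy_meta_net, inner_lr, ...) are placeholders for illustration only, not my actual code:

    import torch
    import torch.nn as nn
    import torch.nn.functional as F

    # Toy stand-ins for fmodel / meta_net, only to illustrate the intended flow.
    toy_model = nn.Linear(4, 1)
    toy_meta_net = nn.Sequential(nn.Linear(1, 1), nn.Sigmoid())
    optim_meta = torch.optim.SGD(toy_meta_net.parameters(), lr=1e-3)
    inner_lr = 1e-4

    x_u = torch.randn(8, 4)                           # "unlabeled" batch
    x_l, y_l = torch.randn(8, 4), torch.randn(8, 1)   # "labeled" meta batch

    # Inner step: per-example losses are reweighted by the meta-net; the
    # parameter update must stay on the autograd graph (create_graph=True).
    per_ex_loss = toy_model(x_u).pow(2)               # placeholder inner loss
    weight = toy_meta_net(per_ex_loss.detach())       # weights depend on the meta-net only
    inner_loss = torch.sum(weight * per_ex_loss) / torch.sum(weight)
    grads = torch.autograd.grad(inner_loss, tuple(toy_model.parameters()),
                                create_graph=True)

    # Functional "updated" parameters: plain tensors that keep their grad history.
    w_new, b_new = [p - inner_lr * g for p, g in zip(toy_model.parameters(), grads)]

    # Outer step: evaluate the updated parameters on the meta batch and
    # backpropagate through the inner update into the meta-net.
    meta_loss = F.mse_loss(x_l @ w_new.t() + b_new, y_l)
    optim_meta.zero_grad()
    meta_loss.backward()
    optim_meta.step()
    print([p.grad is not None for p in toy_meta_net.parameters()])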

The following is my code:

for train_iter in range(total_itr_num):
        try:
            inputs_x, outcome = next(dataloader_lb_itr)
        except:
            dataloader_lb_itr = iter(dataloader_lb)
            inputs_x, outcome = next(dataloader_lb_itr)
        inputs_x = inputs_x.to(device)
        outcome = outcome.to(device)
        targets_x = outcome.view(-1)
        indices = torch.randperm(len(inputs_x))
        shuffled_inputs_x = inputs_x[indices]
        shuffled_targets_x = targets_x[indices]
        inputs_x_1, inputs_x_2 = shuffled_inputs_x.chunk(2)
        targets_x_1, targets_x_2 = shuffled_targets_x.chunk(2)
        y.append(targets_x_1.view(-1).detach().cpu())
        try:
            (inputs_u_w, inputs_u_s), _ = next(dataloader_unlb_0_itr)
        except:
            dataloader_unlb_0_itr = iter(dataloader_unlb_0)
            (inputs_u_w, inputs_u_s), _ = next(dataloader_unlb_0_itr)
        inputs_u_w = inputs_u_w.to(device)
        inputs_u_s = inputs_u_s.to(device)  

        if train_iter % 1 == 0:
            fmodel.load_state_dict(model.state_dict())

            # label_loss
            batch_size = inputs_u_w.shape[0]
            mean_raw = fmodel(inputs_x_1) 
            label_loss = F.mse_loss(mean_raw, targets_x_1.unsqueeze(1))

            # unlabel_loss
            with torch.no_grad():
                ul_pred_w = fmodel(inputs_u_w) # pseudo label
            ul_pred_s = fmodel(inputs_u_s)
            unlabel_loss = F.mse_loss(ul_pred_s, ul_pred_w, reduction='none') * (1.0 / len(inputs_u_w))

            # meta weight
            weight = meta_net(unlabel_loss.detach())
            norm = torch.sum(weight)
            unlabel_loss_hat = torch.sum(weight * unlabel_loss) * (1.0 / norm)
            
            loss = label_loss  + w_ulb * unlabel_loss_hat
            
            # update the fmodel
            fmodel.zero_grad()
            grads = torch.autograd.grad(loss, (fmodel.parameters()), create_graph=True)
            update_params(fmodel, lr=0.0001, source_params=grads)
            del grads

            # update meta_net
            mean_meta = fmodel(inputs_x_2)
            targets_meta = targets_x_2.unsqueeze(1)
            label_meta = F.mse_loss(mean_meta, targets_meta)

            optim_meta.zero_grad()
            label_meta.backward()
            optim_meta.step()
            for param in meta_net.parameters():
                print(f"Gradient: {param.grad}")
            exit()

        # label_loss
        mean_raw = model(inputs_x_1)
        yhat_0.append(mean_raw.view(-1).to("cpu").detach())
        label_loss = F.mse_loss(mean_raw, targets_x_1.unsqueeze(1))

        # unlabel_loss
        with torch.no_grad():
            ul_pred_w = model(inputs_u_w) # pseudo label
        ul_pred_s = model(inputs_u_s)
        unlabel_loss = F.mse_loss(ul_pred_s, ul_pred_w, reduction='none') * (1.0 / len(inputs_u_w))
 
        # meta weight
        with torch.no_grad():
            weight = meta_net(unlabel_loss.detach())
            wei.append(weight.view(-1).to("cpu"))
            norm = torch.sum(weight)
        unlabel_loss_hat = torch.sum(weight * unlabel_loss) * (1.0 / norm)

        loss = label_loss + w_ulb * unlabel_loss_hat

        if train:
            optim.zero_grad()
            loss.backward()
            optim.step()
  • The following is the result:

(image of the printed output: every meta_net parameter gradient is None)

  • The update functions used in the training step (a minimal usage sketch follows the code):

    def update_params(model, lr, source_params):
        for tgt, src in zip(model.named_parameters(), source_params):
            name_t, param_t = tgt
            grad = src
            tmp = param_t - lr * grad
            set_param(model, name_t, tmp)
    
    
    def set_param(curr_mod, name, param):
        if '.' in name:
            n = name.split('.')
            module_name = n[0]
            rest = '.'.join(n[1:])
            for name, mod in curr_mod.named_children():
                if module_name == name:
                    set_param(mod, rest, param)
                    break
        else:
            if not isinstance(param, torch.nn.Parameter):
                param = torch.nn.Parameter(param)
            setattr(curr_mod, name, param)
            # getattr(curr_mod, name).data.copy_(param)
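
For completeness, here is a minimal, self-contained usage sketch of update_params on a toy module; the nn.Sequential and random data below are placeholders, not my real training setup:

    import torch
    import torch.nn as nn
    import torch.nn.functional as F

    # Toy module and data, just to exercise update_params / set_param.
    toy = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 1))
    x, y = torch.randn(16, 4), torch.randn(16, 1)

    loss = F.mse_loss(toy(x), y)
    # Same call pattern as in the training loop: keep the graph for the meta step.
    grads = torch.autograd.grad(loss, tuple(toy.parameters()), create_graph=True)
    update_params(toy, lr=0.0001, source_params=grads)

    # The parameters now hold the SGD-style updated values.
    print(F.mse_loss(toy(x), y).item())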