I’m working on a meta-learning problem. First I use the training loss to update the parameters of fmodel, and then I run the updated fmodel on held-out labelled data and backpropagate that meta loss to get second-order gradients with respect to meta_net. However, during the update I found that the meta loss (label_meta in the code below) has no gradient with respect to meta_net. Could anyone help me figure out where the problem lies?
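For reference, the chain I expect is: inner loss -> gradients of fmodel taken with create_graph=True -> manual parameter update -> meta loss on held-out data -> gradient flowing back into meta_net. A minimal standalone sketch of that chain (toy tensors only; w and meta_w are just stand-ins, not my real model):

import torch

# Toy stand-ins: w plays the role of fmodel's parameters,
# meta_w plays the role of meta_net's parameters.
w = torch.randn(3, requires_grad=True)
meta_w = torch.randn(3, requires_grad=True)
x, y = torch.randn(3), torch.randn(3)

# Inner loss weighted by the meta parameter, then a differentiable update of w.
inner_loss = (torch.sigmoid(meta_w) * (w * x - y) ** 2).mean()
grad_w = torch.autograd.grad(inner_loss, w, create_graph=True)[0]
w_updated = w - 0.1 * grad_w          # plain tensor arithmetic, graph preserved

# Meta loss on the updated parameters; the gradient should flow back to meta_w.
meta_loss = ((w_updated * x - y) ** 2).mean()
meta_grad = torch.autograd.grad(meta_loss, meta_w)[0]
print(meta_grad)                      # expected: a non-zero tensor

As far as I understand, as long as the inner update stays inside the autograd graph, meta_grad should come out non-zero; in my real code the gradient for meta_net never shows up.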
The following is my code:
for train_iter in range(total_itr_num):
    # labelled batch
    try:
        inputs_x, outcome = next(dataloader_lb_itr)
    except StopIteration:
        dataloader_lb_itr = iter(dataloader_lb)
        inputs_x, outcome = next(dataloader_lb_itr)
    inputs_x = inputs_x.to(device)
    outcome = outcome.to(device)
    targets_x = outcome.view(-1)

    # shuffle and split the labelled batch into a training half and a meta half
    indices = torch.randperm(len(inputs_x))
    shuffled_inputs_x = inputs_x[indices]
    shuffled_targets_x = targets_x[indices]
    inputs_x_1, inputs_x_2 = shuffled_inputs_x.chunk(2)
    targets_x_1, targets_x_2 = shuffled_targets_x.chunk(2)
    y.append(targets_x_1.view(-1).detach().cpu())

    # unlabelled batch (weak / strong augmentations)
    try:
        (inputs_u_w, inputs_u_s), _ = next(dataloader_unlb_0_itr)
    except StopIteration:
        dataloader_unlb_0_itr = iter(dataloader_unlb_0)
        (inputs_u_w, inputs_u_s), _ = next(dataloader_unlb_0_itr)
    inputs_u_w = inputs_u_w.to(device)
    inputs_u_s = inputs_u_s.to(device)
    if train_iter % 1 == 0:
        # copy the current model weights into the functional copy
        fmodel.load_state_dict(model.state_dict())

        # label_loss
        batch_size = inputs_u_w.shape[0]
        mean_raw = fmodel(inputs_x_1)
        label_loss = F.mse_loss(mean_raw, targets_x_1.unsqueeze(1))

        # unlabel_loss
        with torch.no_grad():
            ul_pred_w = fmodel(inputs_u_w)  # pseudo label
        ul_pred_s = fmodel(inputs_u_s)
        unlabel_loss = F.mse_loss(ul_pred_s, ul_pred_w, reduction='none') * (1.0 / len(inputs_u_w))

        # meta weight
        weight = meta_net(unlabel_loss.detach())
        norm = torch.sum(weight)
        unlabel_loss_hat = torch.sum(weight * unlabel_loss) * (1.0 / norm)
        loss = label_loss + w_ulb * unlabel_loss_hat

        # update the fmodel (differentiable inner step)
        fmodel.zero_grad()
        grads = torch.autograd.grad(loss, fmodel.parameters(), create_graph=True)
        update_params(fmodel, lr=0.0001, source_params=grads)
        del grads

        # update meta_net with the meta loss on the held-out half
        mean_meta = fmodel(inputs_x_2)
        targets_meta = targets_x_2.unsqueeze(1)
        label_meta = F.mse_loss(mean_meta, targets_meta)
        optim_meta.zero_grad()
        label_meta.backward()
        optim_meta.step()
        for param in meta_net.parameters():
            print(f"Gradient: {param.grad}")
        exit()
    # label_loss
    mean_raw = model(inputs_x_1)
    yhat_0.append(mean_raw.view(-1).to("cpu").detach())
    label_loss = F.mse_loss(mean_raw, targets_x_1.unsqueeze(1))

    # unlabel_loss
    with torch.no_grad():
        ul_pred_w = model(inputs_u_w)  # pseudo label
    ul_pred_s = model(inputs_u_s)
    unlabel_loss = F.mse_loss(ul_pred_s, ul_pred_w, reduction='none') * (1.0 / len(inputs_u_w))

    # meta weight (no gradient needed here)
    with torch.no_grad():
        weight = meta_net(unlabel_loss.detach())
        wei.append(weight.view(-1).to("cpu"))
    norm = torch.sum(weight)
    unlabel_loss_hat = torch.sum(weight * unlabel_loss) * (1.0 / norm)
    loss = label_loss + w_ulb * unlabel_loss_hat

    if train:
        optim.zero_grad()
        loss.backward()
        optim.step()
This is the result: the print loop above shows that the parameters of meta_net receive no gradient.
This is the update function used in the training step:
def update_params(model, lr, source_params):
    for tgt, src in zip(model.named_parameters(), source_params):
        name_t, param_t = tgt
        grad = src
        tmp = param_t - lr * grad
        set_param(model, name_t, tmp)

def set_param(curr_mod, name, param):
    if '.' in name:
        n = name.split('.')
        module_name = n[0]
        rest = '.'.join(n[1:])
        for name, mod in curr_mod.named_children():
            if module_name == name:
                set_param(mod, rest, param)
                break
    else:
        if not isinstance(param, torch.nn.Parameter):
            param = torch.nn.Parameter(param)
        setattr(curr_mod, name, param)
        # getattr(curr_mod, name).data.copy_(param)
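One thing I am unsure about: in set_param the updated tensor tmp is wrapped in torch.nn.Parameter before being assigned back to the module. If I understand correctly, an nn.Parameter is always a leaf tensor, so this wrapping might be cutting the graph that links the updated fmodel back to meta_net, but I may be wrong about that. A tiny standalone check of what I mean (nothing to do with my real model):

import torch

a = torch.randn(2, requires_grad=True)
b = a * 2
print(b.grad_fn)             # a MulBackward0 node: b is still connected to a
p = torch.nn.Parameter(b)
print(p.grad_fn, p.is_leaf)  # I expect None and True: p is a fresh leaf,
                             # so the history leading back to a is gone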