I am currently re-implementing the paper "Towards Fast, Accurate and Stable 3D Dense Face Alignment," and I'm having trouble figuring out the implementation of the Meta-Joint Optimization part. As I understand it, it is essentially model selection every k steps, with two loss functions involved: VDC and wPDC. Between selection steps, two copies of the model parameters Theta are updated in parallel, one by each loss, giving Theta_VDC and Theta_wPDC. Every k-th iteration, both copies are evaluated on a validation batch, and the one with the better metric overwrites the other and serves as the common starting point for the next k iterations. Below is my implementation, but the loss always becomes inf after a few iterations.
import copy

import torch

# two copies of the same network: one trained with VDC, the other with wPDC
model_vdc = mobilenet()
model_wpdc = copy.deepcopy(model_vdc)
optimizer_vdc = torch.optim.AdamW(params=model_vdc.parameters(), lr=lr)
optimizer_wpdc = torch.optim.AdamW(params=model_wpdc.parameters(), lr=lr)

for epoch in range(N):
    for batch_idx, batch in enumerate(trainloader):
        if batch_idx == 0 or batch_idx % meta_joint_k != 0:
            inputs, targets = batch
            # update one copy by the VDC loss
            # (vdc_loss / wpdc_loss stand for my loss implementations, omitted here)
            loss_vdc = vdc_loss(model_vdc(inputs), targets)
            loss_vdc.backward()
            optimizer_vdc.step()
            optimizer_vdc.zero_grad()
            # update the other copy by the wPDC loss
            loss_wpdc = wpdc_loss(model_wpdc(inputs), targets)
            loss_wpdc.backward()
            optimizer_wpdc.step()
            optimizer_wpdc.zero_grad()
        else:  # every meta_joint_k-th batch: validate both copies, keep the winner
            model_vdc.eval(); model_wpdc.eval()
            # calculate the metric on a validation batch (lower is better)
            ......
            if metric_vdc > metric_wpdc:
                # wPDC copy is better: overwrite the VDC copy's weights
                # and optimizer state
                model_vdc.load_state_dict(model_wpdc.state_dict())
                optimizer_vdc.load_state_dict(optimizer_wpdc.state_dict())
            else:
                model_wpdc.load_state_dict(model_vdc.state_dict())
                optimizer_wpdc.load_state_dict(optimizer_vdc.state_dict())
            model_vdc.train(); model_wpdc.train()
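For context, the elided metric step is roughly the following sketch; val_inputs, val_targets, and decode_landmarks are simplified placeholder names for my fixed validation batch and the function that reconstructs landmarks from the regressed parameters, and the metric is an error, so lower is better:

with torch.no_grad():
    # decode_landmarks is a placeholder: parameters -> landmark coordinates
    lmk_gt = decode_landmarks(val_targets)
    lmk_vdc = decode_landmarks(model_vdc(val_inputs))
    lmk_wpdc = decode_landmarks(model_wpdc(val_inputs))
    # mean per-point distance to the ground-truth landmarks (lower is better)
    metric_vdc = (lmk_vdc - lmk_gt).norm(dim=-1).mean().item()
    metric_wpdc = (lmk_wpdc - lmk_gt).norm(dim=-1).mean().item()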
Thank you in advance.