Recently I read a paper about an overview of meta pseudo label and I found a Notebook that implement it. I saw a line of code ( (t_l_loss+t_mpl_loss).backward()
) and I don’t understand how it works.
Could you please give me a hint that how we can sum two losses and call backward()?
`
experiment4: meta-pesudo-label ( exact, no approximation)
if 1:
l_x = torch.from_numpy(x_train).float().cuda()
l_y = torch.from_numpy(y_train).float().cuda()
teacher = Net().cuda()
student = Net().cuda()
t_optimizer = optim.SGD(teacher.parameters(),lr=0.001, momentum=0.9)
s_optimizer = optim.SGD(student.parameters(),lr=0.001, momentum=0.9)
for iteration in range(200):
# subscript: t,s : teacher,student
# subscript: l,u : label,unlabel
teacher.train()
student.train()
t_optimizer.zero_grad()
s_optimizer.zero_grad()
random_sample = np.random.choice(num_unlabel*2,16).tolist()
u_x = torch.from_numpy(x_unlabel[random_sample]).float().cuda()
#prepare other input
#note these is detached, i.e. not used for backprop
#s_l_logit = student(l_x)
#s_l_loss = F.binary_cross_entropy_with_logits(s_l_logit.detach(), l_y)
t_u_logit = teacher(u_x)
pseudo_y = torch.sigmoid(t_u_logit) #t_u_logit>0).float()
#------
#train student : update student using pesudo label data only
s_u_logit = student(u_x)
s_u_loss = F.binary_cross_entropy_with_logits(s_u_logit, pseudo_y,reduction='none')
s_u_loss = s_u_loss[torch.abs(pseudo_y-0.5)>0.45].mean()
s_u_loss.backward()
s_optimizer.step()
#train teacher : update teacher using pesudo label data student change in loss
s_l_logit_new = student(l_x)
s_l_loss_new = F.binary_cross_entropy_with_logits(s_l_logit_new, l_y)
t_mpl_loss = s_l_loss_new
t_l_logit = teacher(l_x)
t_l_loss = F.binary_cross_entropy_with_logits(t_l_logit, l_y)
(t_l_loss+t_mpl_loss).backward()
#(t_mpl_loss).backward()
t_optimizer.step()
#print(iteration, s_u_loss.item(), t_l_loss.item(), t_mpl_loss.item())
#-------------------------------
print('metal pseudo label (exact) : student')
print_validate(student)
show_predict_space(student)
plt.title('metal pseudo label (exact): student')
plt.show()
print('metal pseudo label (exact): teacher')
print_validate(teacher)
show_predict_space(teacher)
plt.title('metal pseudo label (exact): teacher')
plt.show()
`