I have two networks, a student and a teacher. Both are trained on a recognition task against the ground-truth labels (`targets`). In addition, I want the student network to learn from the teacher network.
loss1: trains the student on the ground-truth labels.
loss2: trains the teacher on the ground-truth labels.
dist_loss: lets the student learn from the teacher.
Is this correct? I want only the student to learn from the teacher, not the other way around. Is detaching the teacher's output the correct way to keep the dist_loss term from sending gradients into the teacher network?
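import torch.nn.functional as F
T = 4.0  # softmax temperature (example value; the post does not give one)

def dist_loss(t, s):
    # Temperature-softened teacher probabilities and student log-probabilities
    prob_t = F.softmax(t / T, dim=1)
    log_prob_s = F.log_softmax(s / T, dim=1)
    # Soft cross-entropy, summed over classes and averaged over the batch
    return -(prob_t * log_prob_s).sum(dim=1).mean()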
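loss1 = F.cross_entropy(student_outputs, targets)  # student vs. ground truth
loss2 = F.cross_entropy(teacher_outputs, targets)  # teacher vs. ground truth
# Teacher logits are detached before being passed to the distillation term
teacher_p = teacher_outputs.detach()
loss = loss1 + loss2 + dist_loss(teacher_p, student_outputs)
optimizer.zero_grad()  # optimizer is assumed to hold both networks' parameters
loss.backward()
optimizer.step()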
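One way to check this empirically is a minimal, self-contained sketch like the one below. It assumes two hypothetical `nn.Linear` models, random inputs and labels, a temperature of 4.0, and a single SGD optimizer over both parameter sets (none of these come from the original code). It compares the teacher's weight gradient from the combined loss against the gradient of `loss2` alone, and also against the same combined loss built without `detach()`.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

T = 4.0  # distillation temperature (example value, assumed)

def dist_loss(t, s):
    # Soft cross-entropy between temperature-softened teacher and student outputs
    prob_t = F.softmax(t / T, dim=1)
    log_prob_s = F.log_softmax(s / T, dim=1)
    return -(prob_t * log_prob_s).sum(dim=1).mean()

torch.manual_seed(0)
# Hypothetical tiny networks and random data, for illustration only
student = nn.Linear(10, 3)
teacher = nn.Linear(10, 3)
inputs = torch.randn(16, 10)
targets = torch.randint(0, 3, (16,))

# A single optimizer over both networks, matching the single optimizer.step()
optimizer = torch.optim.SGD(
    list(student.parameters()) + list(teacher.parameters()), lr=0.1
)

student_outputs = student(inputs)
teacher_outputs = teacher(inputs)

loss1 = F.cross_entropy(student_outputs, targets)
loss2 = F.cross_entropy(teacher_outputs, targets)
teacher_p = teacher_outputs.detach()
loss = loss1 + loss2 + dist_loss(teacher_p, student_outputs)

# Reference teacher gradients, computed with autograd.grad so .grad is untouched
ref_grad = torch.autograd.grad(loss2, teacher.weight, retain_graph=True)[0]
no_detach = loss1 + loss2 + dist_loss(teacher_outputs, student_outputs)
no_detach_grad = torch.autograd.grad(no_detach, teacher.weight, retain_graph=True)[0]

optimizer.zero_grad()
loss.backward()
full_grad = teacher.weight.grad.clone()  # teacher gradient from the combined loss
optimizer.step()

print(torch.allclose(full_grad, ref_grad))       # True: dist_loss adds nothing to the teacher
print(torch.allclose(no_detach_grad, ref_grad))  # without detach, the teacher gets extra gradient
```

In this sketch `loss2` is the only term whose gradient reaches the teacher when its logits are detached, so the teacher is still trained on the ground-truth labels while `dist_loss` only updates the student.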
Thank you