Is this loss function for knowledge distillation correct?

I have two networks: a student and a teacher. Both are being trained for a recognition task (against ground-truth values, “targets”). I also want the student network to learn from the teacher network.

loss1: takes care of the student learning from the ground-truth values
loss2: takes care of the teacher learning from the ground-truth values
dist_loss: Is this correct? I want to ensure that only the student learns from the teacher and not the other way around. Is detaching the teacher's output the correct way to prevent the teacher network's parameters from being updated by gradients from the dist_loss term?

import torch.nn.functional as F

def dist_loss(t, s):
    # T is the softmax temperature (defined elsewhere in the script)
    # soften the teacher logits and take the student's log-probabilities
    prob_t = F.softmax(t/T, dim=1)
    log_prob_s = F.log_softmax(s/T, dim=1)
    # cross-entropy between the teacher and student distributions, averaged over the batch
    return -(prob_t*log_prob_s).sum(dim=1).mean()

loss1 = F.cross_entropy(student_outputs, targets)   # student vs. ground truth
loss2 = F.cross_entropy(teacher_outputs, targets)   # teacher vs. ground truth
teacher_p = teacher_outputs.detach()                # block gradients from flowing into the teacher
loss = loss1 + loss2 + dist_loss(teacher_p, student_outputs)
optimizer.zero_grad()
loss.backward()
optimizer.step()

Thank you


If you want to implement regular offline knowledge distillation, there is no need to add loss2, since the teacher should already be trained. The loss function for the student is a combination of the teacher-generated probabilities (softened with a high temperature) and the ground-truth labels.
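A minimal sketch of that combined student loss (the teacher is assumed to be already trained and frozen; T, alpha and the T*T scaling are choices made here purely for illustration):

import torch.nn.functional as F

def student_kd_loss(student_logits, teacher_logits, targets, T=4.0, alpha=0.5):
    # soft targets: teacher probabilities at a high temperature
    soft_targets = F.softmax(teacher_logits / T, dim=1)
    log_probs_student = F.log_softmax(student_logits / T, dim=1)
    # KL divergence between the softened teacher and student distributions;
    # scaling by T*T keeps gradient magnitudes comparable across temperatures
    soft_loss = F.kl_div(log_probs_student, soft_targets, reduction="batchmean") * (T * T)
    # hard loss: student vs. ground-truth labels
    hard_loss = F.cross_entropy(student_logits, targets)
    return alpha * soft_loss + (1.0 - alpha) * hard_loss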

Agreed with Aryan. Also, I believe you should balance, i.e. weight or normalize, your losses rather than simply adding them up.

The standard practice is to pre-train your teacher and save it. Then, while training the student, you take a linear combination of two losses as your loss function:

  1. A loss between the gold labels and the student’s prediction
  2. A loss between the teacher’s prediction and the student’s prediction

This is the pattern followed by the TinyBERT paper. I’ve done this in practice as well and it works pretty decently.
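A rough sketch of that recipe (the model classes, checkpoint path and data loader are placeholders, and student_kd_loss can be any soft-target loss, such as the one sketched above):

import torch

teacher = TeacherNet()                              # placeholder architecture
teacher.load_state_dict(torch.load("teacher.pth"))  # pre-trained and saved earlier
teacher.eval()
for p in teacher.parameters():
    p.requires_grad_(False)                         # teacher is frozen

student = StudentNet()                              # placeholder architecture
optimizer = torch.optim.Adam(student.parameters(), lr=1e-3)

for inputs, targets in train_loader:
    with torch.no_grad():                           # no graph is built for the teacher
        teacher_logits = teacher(inputs)
    student_logits = student(inputs)

    # linear combination of (1) gold-label loss and (2) teacher-student loss
    loss = student_kd_loss(student_logits, teacher_logits, targets)

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()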


Exclude loss2 from your overall loss calculation.
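In terms of the original snippet, that would look roughly like this (keeping the question’s variable names; the teacher is assumed to be frozen or trained separately):

teacher_p = teacher_outputs.detach()   # or compute teacher_outputs under torch.no_grad()
loss = F.cross_entropy(student_outputs, targets) + dist_loss(teacher_p, student_outputs)
optimizer.zero_grad()                  # optimizer should only hold the student's parameters
loss.backward()
optimizer.step()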