Since knowledge distillation seems to work on the softmax of the logits, how can we apply knowledge distillation in a multi-label setting where generally a sigmoid is used to obtain the probability of each label?
Not 100% sure, but I think the following should work.

For a standard multi-class classification task, the distillation loss is:

alpha * loss_fn(student_model_output, one_hot_encoded_labels) + (1 - alpha) * kl_divergence_loss(teacher_model_output, student_model_output)
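As a concrete sketch of that multi-class loss in NumPy (function names, the temperature `T`, and the `T**2` scaling follow the usual Hinton-style distillation recipe, but are my own choices here, not part of the question):

```python
import numpy as np

def softmax(z, T=1.0):
    # temperature-scaled, numerically stable softmax
    z = z / T
    z = z - z.max(axis=-1, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=-1, keepdims=True)

def kd_loss_multiclass(student_logits, teacher_logits, one_hot, alpha=0.5, T=2.0):
    # hard-label term: cross-entropy against the one-hot targets
    p_student = softmax(student_logits)
    ce = -np.sum(one_hot * np.log(p_student + 1e-12), axis=-1).mean()
    # soft-label term: KL(teacher || student) on temperature-softened distributions
    p_t = softmax(teacher_logits, T)
    p_s = softmax(student_logits, T)
    kl = np.sum(p_t * (np.log(p_t + 1e-12) - np.log(p_s + 1e-12)), axis=-1).mean()
    # T**2 rescales the soft term's gradient magnitude back to the hard term's scale
    return alpha * ce + (1 - alpha) * (T ** 2) * kl
```

Note that when the student's logits equal the teacher's, the KL term vanishes and only the hard-label cross-entropy remains.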
Since in a multi-label setup each output is a probability in its own right, I think the modified distillation loss should be (with n = number of outputs):

alpha * loss_fn(student_model_output, multi_hot_encoded_labels) +
(1 - alpha) * 1/n * sum(kl_divergence_loss(teacher_output[output_idx], student_output[output_idx]) for output_idx in range(n))
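A sketch of that multi-label version in NumPy: each sigmoid output is treated as a Bernoulli distribution, so the per-label KL is the KL divergence between two Bernoullis, averaged over the n outputs. The function names and the binary-cross-entropy choice for `loss_fn` are my own assumptions:

```python
import numpy as np

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

def bernoulli_kl(p_t, p_s, eps=1e-12):
    # elementwise KL(teacher || student) between Bernoulli distributions
    return (p_t * np.log((p_t + eps) / (p_s + eps))
            + (1 - p_t) * np.log((1 - p_t + eps) / (1 - p_s + eps)))

def kd_loss_multilabel(student_logits, teacher_logits, hard_labels, alpha=0.5):
    p_s = sigmoid(student_logits)
    p_t = sigmoid(teacher_logits)
    # hard-label term: binary cross-entropy against the multi-hot targets
    bce = -(hard_labels * np.log(p_s + 1e-12)
            + (1 - hard_labels) * np.log(1 - p_s + 1e-12)).mean()
    # soft-label term: the 1/n * sum above, i.e. mean Bernoulli KL over outputs
    kl = bernoulli_kl(p_t, p_s).mean()
    return alpha * bce + (1 - alpha) * kl
```

As in the multi-class case, a student that matches the teacher exactly makes the KL term zero, leaving only the weighted hard-label loss.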