Why detach probs in symmetric_kl function (deberta sift)

In the official implementation of sift by Microsoft/Deberta here, they implemented the symmetric_kl function as:

import torch.nn.functional as F

def symmetric_kl(logits, target):
  logit_stu = logits.view(-1, logits.size(-1)).float()
  logit_tea = target.view(-1, target.size(-1)).float()
  logprob_stu = F.log_softmax(logit_stu, -1)
  logprob_tea = F.log_softmax(logit_tea, -1)
  prob_tea = logprob_tea.exp().detach()
  prob_stu = logprob_stu.exp().detach()
  floss = ((prob_tea*(-logprob_stu)).sum(-1))    # Cross Entropy
  bloss = ((prob_stu*(-logprob_tea)).sum(-1))    # Cross Entropy
  loss = floss + bloss
  return loss

My question is: why do you need to detach() when computing prob_tea and prob_stu? I checked and found that its presence affects the computed gradients.
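
Roughly, the check I did looks like this (just a sketch, not the official
code: the detach flag and the toy shapes are only for illustration):

import torch
import torch.nn.functional as F

def symmetric_kl_check(logits, target, detach):
  logit_stu = logits.view(-1, logits.size(-1)).float()
  logit_tea = target.view(-1, target.size(-1)).float()
  logprob_stu = F.log_softmax(logit_stu, -1)
  logprob_tea = F.log_softmax(logit_tea, -1)
  prob_tea = logprob_tea.exp().detach() if detach else logprob_tea.exp()
  prob_stu = logprob_stu.exp().detach() if detach else logprob_stu.exp()
  floss = (prob_tea * (-logprob_stu)).sum(-1)
  bloss = (prob_stu * (-logprob_tea)).sum(-1)
  return floss + bloss

torch.manual_seed(0)
base_logits = torch.randn(2, 5)
base_target = torch.randn(2, 5)

for detach in (True, False):
  logits = base_logits.clone().requires_grad_()
  target = base_target.clone().requires_grad_()
  symmetric_kl_check(logits, target, detach).sum().backward()
  print('detach =', detach)
  print('  logits.grad[0]:', logits.grad[0])
  print('  target.grad[0]:', target.grad[0])

The gradients printed for detach = True and detach = False do not match.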

Hi Sohaib!

The Kullback-Leibler divergence is a particular measure of the dissimilarity
of two probability distributions. It is closely related to cross entropy. It takes
on its minimum of zero when the two probability distributions are the same.

The symmetric_kl() you are asking about is basically a symmetrized version
of the KL divergence, kl (p1, p2) + kl (p2, p1).

floss and bloss are indeed the two cross entropies and they are added
together to get the symmetrized result.
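
One way to see the bookkeeping: each cross entropy equals the
corresponding kl plus the entropy of its weighting distribution, so
floss + bloss and the symmetrized kl differ only by the two entropy
terms. A quick numerical sketch of this (the variable names are mine
and the values arbitrary):

import torch
import torch.nn.functional as F

torch.manual_seed(0)
prob_stu = F.softmax(torch.randn(5), -1)
prob_tea = F.softmax(torch.randn(5), -1)

floss = (prob_tea * (-prob_stu.log())).sum()   # cross entropy, teacher weights
bloss = (prob_stu * (-prob_tea.log())).sum()   # cross entropy, student weights

kl_ts = (prob_tea * (prob_tea.log() - prob_stu.log())).sum()   # kl (tea, stu)
kl_st = (prob_stu * (prob_stu.log() - prob_tea.log())).sum()   # kl (stu, tea)
ent_tea = -(prob_tea * prob_tea.log()).sum()   # entropy of the teacher probs
ent_stu = -(prob_stu * prob_stu.log()).sum()   # entropy of the student probs

print(torch.allclose(floss + bloss, kl_ts + kl_st + ent_tea + ent_stu))   # True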

Leaving aside the details of working in log space (that is, working with
log-probabilities rather than probabilities), let’s consider the cross entropy
of two probability distributions, p1 and p2.

For fixed p2, cross_entropy (p1, p2) takes on its minimum when p1
equals p2. But for fixed p1, it does not take on its minimum when p2 is
equal to p1. (Rather, for fixed p1, it takes on its minimum when p2 has
100% probability for its entry that corresponds to the largest probability
in p1 and 0% for all of its other entries.)
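
Concretely, writing cross_entropy (p1, p2) = -(p2 * log (p1)).sum(), so
that p1 plays the role of the prediction and p2 the target weights (the
same pattern as floss and bloss above), here is a tiny numerical example
with a fixed p1 (the numbers are mine):

import torch

def cross_entropy(p1, p2):
  # p1: "prediction" distribution (inside the log), p2: "target" weights
  return -(p2 * p1.log()).sum()

p1 = torch.tensor([0.5, 0.3, 0.2])           # held fixed

p2_equal = p1.clone()                        # p2 = p1
p2_onehot = torch.tensor([1.0, 0.0, 0.0])    # all mass on p1's largest entry

print(cross_entropy(p1, p2_equal))    # ~1.03 (the entropy of p1)
print(cross_entropy(p1, p2_onehot))   # ~0.69 (smaller, even though p2 != p1)

So pushing p2 to reduce this cross entropy drives it toward a one-hot
rather than toward p1, which is exactly the unhelpful direction
discussed next.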

Consider floss: The part of its gradient due to logprob_stu drives the
probability distribution given (in log space) by logits to match target.
This is what we want. But, if we didn’t detach prob_tea, the part of
floss’s gradient due to prob_tea would drive target to concentrate all
of its probability in the single element that corresponded to the element
of logits with the largest probability. This is not what we want and would
work at cross-purposes to the logprob_stu part of the gradient. This is
why we detach prob_tea. This causes prob_tea’s gradient – which we
don’t want – to be thrown away.

The same story applies to bloss, but with logits and target switched.
For bloss, the part of its gradient due to prob_stu is unhelpful, so we
detach prob_stu.
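
To make the effect of the detaches concrete, here is a small sketch
(toy shapes, and I am reusing the names from the function above): with
prob_tea detached, floss sends gradient only into logits, namely the
familiar softmax-cross-entropy gradient prob_stu - prob_tea, and nothing
into target; bloss with prob_stu detached is the mirror image.

import torch
import torch.nn.functional as F

torch.manual_seed(0)
logits = torch.randn(1, 5, requires_grad=True)
target = torch.randn(1, 5, requires_grad=True)

logprob_stu = F.log_softmax(logits, -1)
logprob_tea = F.log_softmax(target, -1)
prob_stu = logprob_stu.exp()
prob_tea = logprob_tea.exp()

# floss with prob_tea detached: gradient flows only into logits
floss = (prob_tea.detach() * (-logprob_stu)).sum(-1)
floss.sum().backward()
print(logits.grad)                            # equals prob_stu - prob_tea
print(prob_stu.detach() - prob_tea.detach())
print(target.grad)                            # None: nothing reaches target

# bloss with prob_stu detached: the mirror image, gradient only into target
bloss = (prob_stu.detach() * (-logprob_tea)).sum(-1)
bloss.sum().backward()
print(target.grad)                            # equals prob_tea - prob_stu
print(logits.grad)                            # unchanged by bloss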

Best.

K. Frank
