KL divergence negative with AMP

I’m computing the KL divergence between two categorical distributions with torch.distributions.kl.kl_divergence. With AMP enabled, the result for my particular inputs is frequently negative (>40% of cases), but only very rarely (<1%) without AMP. According to the autocast op reference, however, kl_div should be autocast to float32 anyway. Does that entry only apply to torch.nn.functional.kl_div and not to torch.distributions.kl.kl_divergence, or is this an oversight?

Could you post a minimal, executable code snippet showing how you use torch.distributions.kl.kl_divergence so that we can check which internal dtypes are used?
I’m currently unsure whether this particular op is causing the issue or whether its input is already problematic.

from torch.distributions import Categorical, kl_divergence
from torch.nn.functional import kl_div, log_softmax
import torch

# fp16 logits on the GPU; note that 1.01 is not exactly representable in float16
l1 = torch.tensor([[1, 2, 3]], device='cuda', dtype=torch.float16)
l2 = torch.tensor([[1.01, 2, 3]], device='cuda', dtype=torch.float16)

print('NO AMP:')
print(kl_divergence(Categorical(logits=l1), 
                    Categorical(logits=l2)).item())
# note: F.kl_div(input, target) computes KL(target || input), so this matches the call above
print(kl_div(log_softmax(l2, dim=-1),
             log_softmax(l1, dim=-1),
             reduction='batchmean', log_target=True).item())
print(Categorical(logits=l1).logits.dtype, 
      log_softmax(l1, dim=-1).dtype)


with torch.cuda.amp.autocast():
    print('\nAMP:')
    print(kl_divergence(Categorical(logits=l1), 
                        Categorical(logits=l2)).item())
    print(kl_div(log_softmax(l2, dim=-1), 
                 log_softmax(l1, dim=-1), 
                 reduction='batchmean', log_target=True).item())
    print(Categorical(logits=l1).logits.dtype, 
          log_softmax(l1, dim=-1).dtype)


NO AMP:
-0.0008792877197265625
-0.00015211105346679688
torch.float16 torch.float16

AMP:
-0.0008792047156020999
3.946683136746287e-06
torch.float16 torch.float32

(the correct KL divergence should be ~4.1e-06)
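
For reference, recomputing the same KL in float64 (assuming float64 is effectively exact at this scale) is a quick way to confirm that value:

import torch
from torch.distributions import Categorical, kl_divergence

# float64 reference; 1.01 is represented much more accurately here
l1_64 = torch.tensor([[1.0, 2.0, 3.0]], dtype=torch.float64)
l2_64 = torch.tensor([[1.01, 2.0, 3.0]], dtype=torch.float64)
print(kl_divergence(Categorical(logits=l1_64),
                    Categorical(logits=l2_64)).item())  # ~4.1e-06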

You’re right. It seems that neither result is correct when the inputs are already in float16. Under autocast, log_softmax casts its output to float32, which is why nn.functional.kl_div only appeared to be correct. For this example, the results from functional.kl_div are always better than those from distributions.kl_divergence, whether the inputs are both float32, both float16, or mixed.
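
As a rough sanity check (my own sketch, not a statement about PyTorch internals): for these logits the float16 log-probabilities themselves are only accurate to roughly 1e-4..1e-3, so a true KL of ~4e-06 cannot be resolved from float16 inputs no matter which KL op is used:

import torch
from torch.nn.functional import log_softmax

# absolute error of the fp16 log-probabilities vs. a float64 reference
l1_16 = torch.tensor([[1, 2, 3]], device='cuda', dtype=torch.float16)
err = (log_softmax(l1_16, dim=-1).double()
       - log_softmax(l1_16.double(), dim=-1)).abs().max()
print(err.item())  # on the order of 1e-4 to 1e-3, far larger than the true KL of ~4e-06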

I suppose that in an autocast region:

  • Categorical should compute its logits in float32 (like log_softmax does)?
  • and both kl-div functions should cast their inputs to float32? (A manual upcast along these lines is sketched below.)
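
This is just a sketch of that workaround, assuming an explicit .float() on the fp16 logits is acceptable as a stopgap:

import torch
from torch.distributions import Categorical, kl_divergence

l1 = torch.tensor([[1, 2, 3]], device='cuda', dtype=torch.float16)
l2 = torch.tensor([[1.01, 2, 3]], device='cuda', dtype=torch.float16)

with torch.cuda.amp.autocast():
    # upcast the logits by hand so Categorical operates on float32 internally
    kl = kl_divergence(Categorical(logits=l1.float()),
                       Categorical(logits=l2.float()))
    print(kl.item())  # should land near the float32 functional result above (~3.9e-06)

The remaining gap to the exact ~4.1e-06 would come from 1.01 already having been rounded when the logits were created in float16.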

Thanks for the update and the great debugging!
Indeed, it looks strange that kl_divergence(Categorical()) seems to give worse numerical stability than its kl_div counterpart. We’ll take a look at how the internal implementations differ.
