I’m computing the KL divergence between two categorical distributions using torch.distributions.kl.kl_divergence. When using AMP, the result for my particular inputs is frequently negative (>40% of cases), but only very rarely (<1%) when not using AMP. According to the autocast op reference, however, kl_div should autocast to float32 anyway. I’m wondering whether this is an oversight and the float32 autocast only applies to torch.nn.functional.kl_div and not to torch.distributions.kl.kl_divergence?
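For what it’s worth, here is a quick dtype check of that claim (just a sketch with random logits, not my actual inputs):

import torch
import torch.nn.functional as F

# Random float16 log-probabilities (placeholder inputs, not my real data).
logp = torch.randn(4, 8, device='cuda', dtype=torch.float16).log_softmax(dim=-1)
logq = torch.randn(4, 8, device='cuda', dtype=torch.float16).log_softmax(dim=-1)

with torch.cuda.amp.autocast():
    out = F.kl_div(logp, logq, reduction='batchmean', log_target=True)
print(out.dtype)  # expected: torch.float32 if kl_div is autocast to float32 as documented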
Could you post a minimal, executable code snippet showing the usage of torch.distributions.kl.kl_divergence so that we could check which internal dtypes are used?
I’m currently unsure if this particular op is causing the issues or if its input is already problematic.
from torch.distributions import Categorical, kl_divergence
from torch.nn.functional import kl_div, log_softmax
import torch

# Two nearly identical sets of logits in float16 on the GPU.
l1 = torch.tensor([[1, 2, 3]], device='cuda', dtype=torch.float16)
l2 = torch.tensor([[1.01, 2, 3]], device='cuda', dtype=torch.float16)

print('NO AMP:')
print(kl_divergence(Categorical(logits=l1),
                    Categorical(logits=l2)).item())
print(kl_div(log_softmax(l2, dim=-1),
             log_softmax(l1, dim=-1),
             reduction='batchmean', log_target=True).item())
print(Categorical(logits=l1).logits.dtype,
      log_softmax(l1, dim=-1).dtype)

# Same computations inside an autocast region.
with torch.cuda.amp.autocast():
    print('\nAMP:')
    print(kl_divergence(Categorical(logits=l1),
                        Categorical(logits=l2)).item())
    print(kl_div(log_softmax(l2, dim=-1),
                 log_softmax(l1, dim=-1),
                 reduction='batchmean', log_target=True).item())
    print(Categorical(logits=l1).logits.dtype,
          log_softmax(l1, dim=-1).dtype)
NO AMP:
-0.0008792877197265625
-0.00015211105346679688
torch.float16 torch.float16
AMP:
-0.0008792047156020999
3.946683136746287e-06
torch.float16 torch.float32
(correct kl-div should be ~4.1e-06)
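For completeness, the ~4.1e-06 reference comes from repeating the computation with the same logits in float64 (a sketch, assuming float64 is accurate enough to serve as ground truth here):

import torch
from torch.distributions import Categorical, kl_divergence

# Same logits as above, but in float64, as a baseline.
ref_l1 = torch.tensor([[1, 2, 3]], device='cuda', dtype=torch.float64)
ref_l2 = torch.tensor([[1.01, 2, 3]], device='cuda', dtype=torch.float64)
print(kl_divergence(Categorical(logits=ref_l1),
                    Categorical(logits=ref_l2)).item())  # ~4.1e-06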
You’re right. It seems that neither result is correct when the inputs are already in float16. When using autocast, log_softmax casts to float32, which is why it seemed that nn.functional.kl_div was correct. For this example, the results from functional.kl_div are always better than those from distributions.kl_divergence, whether the inputs are both in float32, both in float16, or mixed.
I suppose in an autocast region:

- Categorical should compute logits in float32 (like log_softmax)?
- and both kl-div functions should cast their inputs to float32?
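One possible workaround sketch in the meantime (no claim that this is the recommended fix): cast the logits to float32 before constructing the distributions, so the KL is computed in float32 even inside an autocast region.

import torch
from torch.distributions import Categorical, kl_divergence

l1 = torch.tensor([[1, 2, 3]], device='cuda', dtype=torch.float16)
l2 = torch.tensor([[1.01, 2, 3]], device='cuda', dtype=torch.float16)

with torch.cuda.amp.autocast():
    # Casting the float16 logits up front keeps the KL computation in float32.
    kl = kl_divergence(Categorical(logits=l1.float()),
                       Categorical(logits=l2.float()))
print(kl.dtype, kl.item())  # dtype should be torch.float32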
Thanks for the update and the great debugging!
Indeed, it looks strange that kl_divergence(Categorical()) seems to give worse numerical stability than the kl_div counterpart. We’ll check how the internal implementations are used.