I was converting the following TensorFlow code to PyTorch:
import tensorflow.compat.v1 as tf
logit_true = tf.distributions.Categorical(probs=logit_true)
logit_aug = tf.distributions.Categorical(probs=logit_aug)
distillation_loss = tf.distributions.kl_divergence(logit_true, logit_aug, allow_nan_stats=False)
My PyTorch implementation:
import torch
logit_true = torch.distributions.categorical.Categorical(probs=logit_true)
logit_aug = torch.distributions.categorical.Categorical(probs=logit_aug)
distillation_loss = torch.distributions.kl.kl_divergence(logit_true, logit_aug)
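To make the PyTorch side easy to reproduce on its own, here is a minimal self-contained sketch of the same computation. The probability tensors are dummy values I made up for illustration, standing in for the real model outputs:

```python
import torch

# Dummy class-probability tensors standing in for the real model outputs
# (shape: batch x num_classes); the values are made up for illustration.
probs_true = torch.tensor([[0.7, 0.2, 0.1],
                           [0.1, 0.8, 0.1]])
probs_aug = torch.tensor([[0.6, 0.3, 0.1],
                          [0.2, 0.7, 0.1]])

dist_true = torch.distributions.Categorical(probs=probs_true)
dist_aug = torch.distributions.Categorical(probs=probs_aug)

# KL(true || aug), one value per batch element; reduce with .mean() or
# .sum() to get a scalar loss.
distillation_loss = torch.distributions.kl.kl_divergence(dist_true, dist_aug)
print(distillation_loss.shape)  # torch.Size([2])
```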
However, the model runs fine with TensorFlow, but with my PyTorch implementation it somehow breaks. I wanted to know whether there is any difference between the two KL divergence implementations.
A little context: it's part of a distillation loss.