Regarding KL divergence in PyTorch (vs TensorFlow)

I was converting the following TensorFlow code to PyTorch:

import tensorflow.compat.v1 as tf

# Wrap the class probabilities in Categorical distributions
logit_true = tf.distributions.Categorical(probs=logit_true)
logit_aug = tf.distributions.Categorical(probs=logit_aug)

# Analytic KL(true || aug); raise an error instead of returning NaN statistics
distillation_loss = tf.distributions.kl_divergence(logit_true, logit_aug, allow_nan_stats=False)

My PyTorch implementation:

import torch

# Same construction on the PyTorch side
logit_true = torch.distributions.Categorical(probs=logit_true)
logit_aug = torch.distributions.Categorical(probs=logit_aug)

distillation_loss = torch.distributions.kl.kl_divergence(logit_true, logit_aug)

However, the model trains fine with the TensorFlow code, but with my PyTorch implementation training somehow goes wrong. I wanted to know whether there is any difference between the two kl_divergence implementations.
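For debugging, here is a minimal sanity check I put together (the dummy tensors p and q are assumptions, not my real data). For two Categorical distributions the KL divergence has the closed form KL(p || q) = sum_i p_i * (log p_i - log q_i), so the PyTorch result should match a manual computation:

import torch

# Dummy probability rows (assumption: each row already sums to 1)
p = torch.tensor([[0.7, 0.2, 0.1]])
q = torch.tensor([[0.5, 0.3, 0.2]])

dist_p = torch.distributions.Categorical(probs=p)
dist_q = torch.distributions.Categorical(probs=q)

# KL from the library: one value per batch element
kl_torch = torch.distributions.kl.kl_divergence(dist_p, dist_q)

# Closed-form KL(p || q) = sum_i p_i * (log p_i - log q_i)
kl_manual = (p * (p.log() - q.log())).sum(dim=-1)

print(kl_torch, kl_manual)  # these should agree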

A little context: it's part of a distillation loss; a sketch of how it enters training follows below.
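Roughly, the loss is used like this (a sketch, not my exact training code; the argument names and the reduction to a scalar with .mean() are my assumptions):

import torch

def distillation_kl(probs_true, probs_aug):
    # probs_true / probs_aug: (batch, num_classes) rows of probabilities
    d_true = torch.distributions.Categorical(probs=probs_true)
    d_aug = torch.distributions.Categorical(probs=probs_aug)
    # kl_divergence returns one value per batch element; reduce to a scalar loss
    return torch.distributions.kl.kl_divergence(d_true, d_aug).mean()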
