Loss function in PyTorch

Hi,
I have converted a loss function from TensorFlow to PyTorch; both versions are below. The original TensorFlow version works. I only replaced each TensorFlow call with its PyTorch equivalent, so I would like to know whether the PyTorch version is correct. Thanks in advance.

PyTorch version:

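import torch

def kl_loss_compute(logits1, logits2):
    # Softmax over the class dimension (dim=1 for [batch, classes] inputs)
    pred1 = torch.nn.functional.softmax(logits1, dim=1)
    pred2 = torch.nn.functional.softmax(logits2, dim=1)
    # Per-sample KL(pred2 || pred1) summed over classes, then averaged over the batch
    loss = torch.mean(torch.sum(pred2 * torch.log(1e-8 + pred2 / (pred1 + 1e-8)), 1))
    return loss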
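A possible alternative, not something the conversion needs in order to be correct: computing log-probabilities with `log_softmax` removes the need for the `1e-8` guards and stays accurate for very confident logits. This is a sketch assuming 2-D `[batch, classes]` inputs; the name `kl_loss_logsoftmax` is hypothetical.

```python
import torch
import torch.nn.functional as F

def kl_loss_logsoftmax(logits1, logits2):
    # Hypothetical variant: log-probabilities straight from the logits, no 1e-8 guards
    logp1 = F.log_softmax(logits1, dim=1)
    logp2 = F.log_softmax(logits2, dim=1)
    # Per row: sum_j p2 * (log p2 - log p1), i.e. KL(pred2 || pred1); then the batch mean
    return torch.mean(torch.sum(logp2.exp() * (logp2 - logp1), dim=1))

# Very confident logits: softmax(a)[0, 1] underflows to 0 in float32
a = torch.tensor([[200.0, 0.0]])
b = torch.tensor([[0.0, 200.0]])
print(kl_loss_logsoftmax(a, b).item())  # 200.0
```

On this extreme input the epsilon-based version saturates at roughly log(1e8) ≈ 18.4, while the `log_softmax` form returns the exact divergence. For ordinary logits the two agree closely.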

TensorFlow version:

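import tensorflow as tf

def kl_loss_compute(logits1, logits2):
    # tf.nn.softmax normalizes over the last axis by default
    pred1 = tf.nn.softmax(logits1)
    pred2 = tf.nn.softmax(logits2)
    # Sum over axis 1, then mean (tf.log is the TF 1.x name; TF 2.x uses tf.math.log)
    loss = tf.reduce_mean(tf.reduce_sum(pred2 * tf.log(1e-8 + pred2 / (pred1 + 1e-8)), 1))

    return loss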
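Written out, both versions compute the same quantity. With $N$ rows, $C$ classes, $p^{(1)} = \mathrm{softmax}(\text{logits1})$, $p^{(2)} = \mathrm{softmax}(\text{logits2})$ and $\epsilon = 10^{-8}$ (notation introduced here for illustration):

```latex
\mathcal{L}
  = \frac{1}{N}\sum_{i=1}^{N}\sum_{j=1}^{C}
      p^{(2)}_{ij}\,\log\!\left(\epsilon + \frac{p^{(2)}_{ij}}{p^{(1)}_{ij} + \epsilon}\right)
  \approx \frac{1}{N}\sum_{i=1}^{N} D_{\mathrm{KL}}\!\left(p^{(2)}_{i} \,\middle\Vert\, p^{(1)}_{i}\right)
```

The inner sum runs over the class axis and the outer mean over the batch, so what has to match between the two frameworks is the softmax axis and the reduction axes; the epsilons only guard against division by zero and log(0).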

The method looks alright if your input tensors are 2-dimensional: tf.nn.softmax uses the last axis by default, which for a 2-D input is the same axis as dim=1 in PyTorch.
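For inputs with more than two dimensions, `dim=1` and TensorFlow's default last axis are no longer the same. A sketch of a variant that mirrors the TensorFlow code for any rank, assuming the class scores sit in the last dimension (the name `kl_loss_compute_tf_style` is hypothetical); it keeps the TensorFlow code's reductions unchanged:

```python
import torch
import torch.nn.functional as F

def kl_loss_compute_tf_style(logits1, logits2):
    # Hypothetical variant: softmax over the last axis, like tf.nn.softmax's default
    pred1 = F.softmax(logits1, dim=-1)
    pred2 = F.softmax(logits2, dim=-1)
    # Same reductions as the TF code: sum over axis 1, then mean over everything left
    return torch.mean(torch.sum(pred2 * torch.log(1e-8 + pred2 / (pred1 + 1e-8)), 1))

torch.manual_seed(0)
# 3-D input: dim=1 and dim=-1 normalize different axes, so the results differ
x = torch.randn(4, 3, 5)
print(torch.allclose(F.softmax(x, dim=1), F.softmax(x, dim=-1)))
# 2-D input: dim=1 is the last axis, so the two choices are identical
y = torch.randn(4, 5)
print(torch.allclose(F.softmax(y, dim=1), F.softmax(y, dim=-1)))

loss = kl_loss_compute_tf_style(x, torch.randn(4, 3, 5))
print(loss.item())
```

Note that for a 3-D input the TensorFlow code's `reduce_sum(..., 1)` sums over axis 1 rather than the class axis; the sketch copies that behavior on purpose so the two stay comparable.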
Did you compare both functions using some random input?
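One way to do that check without TensorFlow installed is to compare the function against `torch.nn.functional.kl_div`. Given log-probabilities as `input` and probabilities as `target`, `kl_div` with `reduction="batchmean"` sums `target * (log(target) - input)` and divides by the batch size, which is the same batch-mean KL(pred2 || pred1). This is a sketch assuming 2-D `[batch, classes]` logits; the helper name `reference_kl` is hypothetical, and the tolerance is loose to absorb the `1e-8` terms.

```python
import torch
import torch.nn.functional as F

def kl_loss_compute(logits1, logits2):
    # The PyTorch version from the question, repeated so the snippet is self-contained
    pred1 = F.softmax(logits1, dim=1)
    pred2 = F.softmax(logits2, dim=1)
    return torch.mean(torch.sum(pred2 * torch.log(1e-8 + pred2 / (pred1 + 1e-8)), 1))

def reference_kl(logits1, logits2):
    # kl_div expects log-probabilities as input and probabilities as target;
    # "batchmean" sums over all elements and divides by the batch size
    return F.kl_div(F.log_softmax(logits1, dim=1), F.softmax(logits2, dim=1),
                    reduction="batchmean")

torch.manual_seed(0)
logits1 = torch.randn(8, 10)
logits2 = torch.randn(8, 10)
ours = kl_loss_compute(logits1, logits2)
ref = reference_kl(logits1, logits2)
print(ours.item(), ref.item())
print(torch.allclose(ours, ref, atol=1e-4))
```

To compare against the TensorFlow version directly, the same random arrays can be fed to both functions and the scalar results compared with a similar tolerance.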