Hi,
I have converted a loss function from TensorFlow to PyTorch; both versions are below. The TensorFlow version works. I would like to know whether the PyTorch version is correct, since I only replaced each TensorFlow function with its PyTorch equivalent. Thanks in advance.
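PyTorch version:
import torch

def kl_loss_compute(logits1, logits2):
    # Softmax over dim=1, the class dimension for (batch, classes) logits
    pred1 = torch.nn.functional.softmax(logits1, dim=1)
    pred2 = torch.nn.functional.softmax(logits2, dim=1)
    # KL(pred2 || pred1) per sample, averaged over the batch; 1e-8 guards log(0) and division by zero
    loss = torch.mean(torch.sum(pred2 * torch.log(1e-8 + pred2 / (pred1 + 1e-8)), 1))
    return loss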
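TensorFlow version:
import tensorflow as tf

def kl_loss_compute(logits1, logits2):
    # tf.nn.softmax normalizes over the last axis by default
    pred1 = tf.nn.softmax(logits1)
    pred2 = tf.nn.softmax(logits2)
    # tf.log is the TF 1.x name (tf.math.log in TF 2.x)
    loss = tf.reduce_mean(tf.reduce_sum(pred2 * tf.log(1e-8 + pred2 / (pred1 + 1e-8)), 1))
    return loss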
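
One way to check a port like this numerically is to evaluate the same formula in plain NumPy and compare both frameworks against it on identical inputs. The sketch below is only an illustration under assumptions: the names `softmax_np` and `kl_loss_reference` are made up for this example, and it assumes 2-D logits of shape (batch, classes), so that the last axis and axis 1 coincide.

```python
import numpy as np

def softmax_np(logits, axis=-1):
    # Subtract the row-wise max for numerical stability; the softmax value is unchanged.
    shifted = logits - np.max(logits, axis=axis, keepdims=True)
    exp = np.exp(shifted)
    return exp / np.sum(exp, axis=axis, keepdims=True)

def kl_loss_reference(logits1, logits2, eps=1e-8):
    # Hypothetical reference: same formula as the two versions above,
    # mean over the batch of sum_j pred2 * log(eps + pred2 / (pred1 + eps)).
    # Assumes 2-D inputs of shape (batch, classes).
    pred1 = softmax_np(logits1)
    pred2 = softmax_np(logits2)
    per_sample = np.sum(pred2 * np.log(eps + pred2 / (pred1 + eps)), axis=1)
    return float(np.mean(per_sample))

# Fixed random inputs so the same arrays can be fed to both frameworks.
rng = np.random.default_rng(0)
a = rng.normal(size=(4, 5)).astype(np.float32)
b = rng.normal(size=(4, 5)).astype(np.float32)
print(kl_loss_reference(a, b))  # positive for different logits
print(kl_loss_reference(a, a))  # very close to zero for identical logits
```

Feeding the same arrays to the PyTorch function, for example `kl_loss_compute(torch.from_numpy(a), torch.from_numpy(b)).item()`, and to the TensorFlow function should give values that agree with the reference up to float32 rounding. For inputs with more than two dimensions the two versions may differ, because `tf.nn.softmax` normalizes over the last axis by default while the PyTorch code uses `dim=1`.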