Categorical cross entropy loss function equivalent in PyTorch

After a few hours of experimenting, here is what I found:

import tensorflow as tf
import torch

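# Shape: (B, C). y_pred holds raw logits; y_true holds one-hot (or soft)
# class probabilities. Random example values for illustration:
B, C = 4, 3
y_pred = torch.randn(B, C)
y_true = torch.nn.functional.one_hot(torch.randint(0, C, (B,)), num_classes=C).float()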

# All of the expressions below compute the same value:
# the batch-mean categorical cross-entropy.
# Assumption: the model outputs raw scores (logits), not Softmax or LogSoftmax outputs.

# PyTorch 1.13.1
-(torch.nn.functional.log_softmax(y_pred, dim=1) * y_true).sum(dim=1).mean()
# Passing (B, C) probabilities as `target` requires PyTorch >= 1.10
torch.nn.functional.cross_entropy(input=y_pred, target=y_true)

# TensorFlow 2.11.0
tf.reduce_mean(
    tf.keras.metrics.categorical_crossentropy(
        y_true.numpy(), torch.nn.functional.softmax(y_pred, dim=1).detach().numpy()
    )
)
tf.reduce_mean(
    tf.keras.metrics.categorical_crossentropy(
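        y_true.numpy(),
        # log_softmax(x) = x - logsumexp(x) differs from the logits only by a
        # per-row constant, so it is still a valid input for from_logits=True
        torch.nn.functional.log_softmax(y_pred, dim=1).detach().numpy(),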
        from_logits=True,
    )
)
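To check the PyTorch-side equivalences numerically without needing TensorFlow, here is a minimal sketch. The helper name `manual_cce`, the seed, and the batch/class sizes are hypothetical choices for illustration. It also compares against integer class-index targets, which `cross_entropy` accepts as well and which give the same value as one-hot targets, and against soft (non-one-hot) probability targets.

```python
import torch
import torch.nn.functional as F


def manual_cce(logits: torch.Tensor, target_probs: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: batch-mean categorical cross-entropy computed
    from raw logits (B, C) and target probabilities (B, C)."""
    log_probs = F.log_softmax(logits, dim=1)
    return -(log_probs * target_probs).sum(dim=1).mean()


torch.manual_seed(0)
B, C = 8, 5  # hypothetical batch size and number of classes
logits = torch.randn(B, C)

# Hard labels, expressed both as class indices and as one-hot probabilities
labels = torch.randint(0, C, (B,))
one_hot = F.one_hot(labels, num_classes=C).float()

loss_manual = manual_cce(logits, one_hot)
loss_probs = F.cross_entropy(logits, one_hot)  # probability targets
loss_index = F.cross_entropy(logits, labels)   # class-index targets

# Soft labels: each row is a valid probability distribution over C classes
soft = F.softmax(torch.randn(B, C), dim=1)
loss_soft_manual = manual_cce(logits, soft)
loss_soft_builtin = F.cross_entropy(logits, soft)

print(loss_manual.item(), loss_probs.item(), loss_index.item())
print(loss_soft_manual.item(), loss_soft_builtin.item())
```

For hard labels, class-index targets avoid materializing a (B, C) one-hot tensor; probability targets are mainly useful when the labels are genuinely soft.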