How should I implement cross-entropy loss with continuous target outputs?

The following code should work in PyTorch 0.2:

import torch
import torch.nn as nn

def cross_entropy(pred, soft_targets):
    logsoftmax = nn.LogSoftmax()  # normalizes each row of a 2-D input
    return torch.mean(torch.sum(-soft_targets * logsoftmax(pred), 1))

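This assumes pred and soft_targets are both Variables of shape (batchsize, num_of_classes), where each row of pred holds the predicted logits and each row of soft_targets is a discrete probability distribution over the classes (non-negative entries that sum to 1). The result is the per-example cross-entropy, summed over classes and averaged over the batch.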
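To check the formula by hand, here is a minimal plain-Python sketch of the same computation, with no PyTorch involved. The helper names log_softmax and soft_cross_entropy and the example numbers are made up for illustration. With one-hot targets the loss should reduce to the usual negative log-probability of the true class.

```python
import math

def log_softmax(row):
    # Numerically stable log-softmax of one row of logits.
    m = max(row)
    log_z = m + math.log(sum(math.exp(x - m) for x in row))
    return [x - log_z for x in row]

def soft_cross_entropy(pred, soft_targets):
    # Mean over rows of -sum_c target[c] * log_softmax(logits)[c].
    losses = []
    for logits, target in zip(pred, soft_targets):
        logp = log_softmax(logits)
        losses.append(-sum(t * lp for t, lp in zip(target, logp)))
    return sum(losses) / len(losses)

# Illustrative inputs: 2 examples, 3 classes.
pred = [[2.0, 0.5, -1.0], [0.0, 0.0, 0.0]]
one_hot = [[1.0, 0.0, 0.0], [0.0, 1.0, 0.0]]
soft = [[0.7, 0.2, 0.1], [1 / 3, 1 / 3, 1 / 3]]

hard_loss = soft_cross_entropy(pred, one_hot)
soft_loss = soft_cross_entropy(pred, soft)
print(hard_loss, soft_loss)
```

For the second row, the logits are all equal, so every class has log-probability -log(3) and that row contributes log(3) to the loss regardless of which distribution is used as the target.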
