Soft Cross Entropy Loss (TF has it, does PyTorch have it?)

Hello Raaj!

I do not believe that PyTorch has a “soft” cross-entropy function built in.
But you can implement it yourself with PyTorch tensor operations, so you still
get the full benefit of autograd and GPU acceleration. The loss is just the
per-sample value -(target * log_softmax (input)).sum(), averaged over the batch.

See this (PyTorch version 0.3.0) script:

import torch
torch.__version__

# define "soft" cross-entropy with pytorch tensor operations
def softXEnt (input, target):
    logprobs = torch.nn.functional.log_softmax (input, dim = 1)
    return  -(target * logprobs).sum() / input.shape[0]

torch.manual_seed (2020)

# input values are logits
input  = torch.autograd.Variable (torch.randn ((2, 5)))
# target values are "soft" probabilities that sum to one (for each sample in batch)
target = torch.nn.functional.softmax (torch.autograd.Variable (torch.randn ((2, 5))), dim = 1)

input
target
softXEnt (input, target)

# make "hard" categorical target
dummy, target_cat = target.max (dim = 1)
# make "hard" one-hot target
target_onehot = torch.zeros_like (target).scatter (1, target_cat.unsqueeze (1), 1)

target_cat
target_onehot
# check that softXEnt agrees with pytorch's cross_entropy for "hard" case
torch.nn.functional.cross_entropy (input, target_cat)
softXEnt (input, target_onehot)

Here is the output:

>>> import torch
>>> torch.__version__
'0.3.0b0+591e73e'
>>>
>>> # define "soft" cross-entropy with pytorch tensor operations
... def softXEnt (input, target):
...     logprobs = torch.nn.functional.log_softmax (input, dim = 1)
...     return  -(target * logprobs).sum() / input.shape[0]
...
>>> torch.manual_seed (2020)
<torch._C.Generator object at 0x000001F948906630>
>>>
>>> # input values are logits
... input  = torch.autograd.Variable (torch.randn ((2, 5)))
>>> # target values are "soft" probabilities that sum to one (for each sample in batch)
... target = torch.nn.functional.softmax (torch.autograd.Variable (torch.randn ((2, 5))), dim = 1)
>>>
>>> input
Variable containing:
 1.2372 -0.9604  1.5415 -0.4079  0.8806
 0.0529  0.0751  0.4777 -0.6759 -2.1489
[torch.FloatTensor of size 2x5]

>>> target
Variable containing:
 0.0629  0.1508  0.5417  0.1899  0.0547
 0.0867  0.0389  0.0408  0.0659  0.7677
[torch.FloatTensor of size 2x5]

>>> softXEnt (input, target)
Variable containing:
 2.4262
[torch.FloatTensor of size 1]

>>>
>>> # make "hard" categorical target
... dummy, target_cat = target.max (dim = 1)
>>> # make "hard" one-hot target
... target_onehot = torch.zeros_like (target).scatter (1, target_cat.unsqueeze (1), 1)
>>>
>>> target_cat
Variable containing:
 2
 4
[torch.LongTensor of size 2]

>>> target_onehot
Variable containing:
 0  0  1  0  0
 0  0  0  0  1
[torch.FloatTensor of size 2x5]

>>> # check that softXEnt agrees with pytorch's cross_entropy for "hard" case
... torch.nn.functional.cross_entropy (input, target_cat)
Variable containing:
 2.2656
[torch.FloatTensor of size 1]

>>> softXEnt (input, target_onehot)
Variable containing:
 2.2656
[torch.FloatTensor of size 1]
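
As an aside, the script above uses the old (0.3.0) Variable API. Here is a
minimal sketch of the same loss written with plain tensors, assuming a
reasonably recent PyTorch release. (The cross_entropy comparison at the end
assumes roughly version 1.10 or later, where cross_entropy accepts “soft”
probability targets directly, and should agree with the hand-rolled version.)

import torch
import torch.nn.functional as F

# same "soft" cross entropy, written for current pytorch (no Variables)
def softXEnt (input, target):
    # input:  raw logits, shape (batch, classes)
    # target: probabilities that sum to one along dim 1, same shape as input
    logprobs = F.log_softmax (input, dim = 1)
    return  -(target * logprobs).sum (dim = 1).mean()

torch.manual_seed (2020)
input  = torch.randn (2, 5)
target = F.softmax (torch.randn (2, 5), dim = 1)

print (softXEnt (input, target))
# on pytorch 1.10+, cross_entropy accepts probability targets,
# so this should print the same value
print (F.cross_entropy (input, target))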

Good luck!

K. Frank
