Cross entropy implementation for multi-dimensional input

Hi, I would like to understand the cross entropy loss for multi-dimensional input by implementing it myself. Currently I can get a matching result by iterating with np.ndindex:

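import numpy as np
import torch
from torch.nn import CrossEntropyLoss

K = 10
X = torch.randn(32, K, 20, 30)           # logits: (batch, classes, 20, 30)
t = torch.randint(K, (32, 20, 30)).long()  # one class label per element
w = torch.randn(K).abs()                 # non-negative per-class weights

CrossEntropyLoss(weight=w, reduction="none")(X, t)

# this gives (up to floating-point error) the same result as:

XX = X.movedim(1, -1)                    # move the class dim last: (32, 20, 30, K)
Ls = torch.zeros_like(t, dtype=torch.float)
for ind in np.ndindex(32, 20, 30):
    label = t[ind]
    # weighted negative log of the softmax probability of the target class
    loss = -w[label] * torch.log(XX[ind][label].exp() / XX[ind].exp().sum())
    Ls[ind] = loss
Ls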
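Written out, each iteration of the loop above computes the following for one position (n, i, j), with t the label tensor, w the class weights and K the number of classes. In other words, it is minus the weight of the target class times the log-softmax of the logits along the class dimension, evaluated at the target class:

```latex
\ell_{n,i,j} = -\,w_{t_{n,i,j}} \,\log \frac{\exp\!\left(X_{n,\,t_{n,i,j},\,i,\,j}\right)}{\sum_{k=0}^{K-1} \exp\!\left(X_{n,k,i,j}\right)}
```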

But this involves an explicit Python loop, which I assume is much slower than a vectorized computation. Is there a trick to implement it in plain torch without the loop? Thanks!

Any ideas? I am looking for a cleaner way to replace the loop. Thanks!

Hi Chris!

You can get a tensor of weights corresponding to the class labels
in t by indexing into w (w[t]), and you can get a tensor of (negative)
log-probabilities of the target classes by calling .take_along_dim()
on the output of log_softmax() with t (unsqueezed to add a class
dimension) as the indices:

>>> import torch
>>> torch.__version__
'1.9.0'
>>> _ = torch.manual_seed (2021)
>>> K = 10
>>> X = torch.randn(32, K, 20, 30)
>>> t = torch.randint(K, (32, 20, 30)).long()
>>> w = torch.randn(K).abs()
>>> xeA = torch.nn.CrossEntropyLoss (weight = w, reduction = 'none') (X, t)
>>> xeB = w[t] * -X.log_softmax (1).take_along_dim (t.unsqueeze (1), 1).squeeze (1)
>>> torch.allclose (xeA, xeB)
True
>>> torch.equal (xeA, xeB)
True
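The same one-liner can also be written with Tensor.gather(), which may help on PyTorch versions that predate take_along_dim() (the session above uses 1.9.0). The sketch below also checks the decomposition through torch.nn.functional.nll_loss(), since cross entropy is log_softmax() followed by the negative log-likelihood loss. Shapes and seed mirror the example above; the names xe_gather and xe_nll are only illustrative.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(2021)
K = 10
X = torch.randn(32, K, 20, 30)        # logits: (N, C, d1, d2)
t = torch.randint(K, (32, 20, 30))    # class labels: (N, d1, d2), int64 by default
w = torch.randn(K).abs()              # per-class weights: (C,)

# Reference: the built-in loss, one value per element
ref = torch.nn.CrossEntropyLoss(weight=w, reduction="none")(X, t)

# gather() picks, at every (n, d1, d2), the log-probability of class t[n, d1, d2]
# along dim 1; the index needs a singleton class dimension, hence unsqueeze(1).
logp = X.log_softmax(dim=1)
picked = logp.gather(1, t.unsqueeze(1)).squeeze(1)   # (N, d1, d2)
xe_gather = -w[t] * picked

# Same decomposition via the functional API: log_softmax followed by nll_loss
xe_nll = F.nll_loss(logp, t, weight=w, reduction="none")

print(torch.allclose(ref, xe_gather), torch.allclose(ref, xe_nll))
```

For this indexing pattern gather() and take_along_dim() behave the same; the nll_loss() route is mostly useful as a check that the manual version matches the library's definition.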

As an aside:

you are computing log-softmax() “by hand” in the straightforward way,
but it is numerically more stable to use the log-sum-exp trick, either
by implementing it explicitly or by using pytorch’s log_softmax()
function (as I have done above).
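To make the aside concrete, here is a minimal sketch of log-softmax written explicitly with the log-sum-exp trick (subtract the per-row maximum before exponentiating), next to the straightforward formula from the question's loop. The helper names log_softmax_lse and log_softmax_naive are hypothetical; torch.log_softmax() (and torch.logsumexp()) already handle this for you.

```python
import torch

def log_softmax_lse(x: torch.Tensor, dim: int) -> torch.Tensor:
    """Log-softmax along `dim` using the log-sum-exp trick (hypothetical helper)."""
    # Shift so the largest entry is 0: exp() of the shifted values is at most 1,
    # so it cannot overflow. Softmax is unchanged by adding a constant per row.
    shifted = x - x.amax(dim=dim, keepdim=True)
    return shifted - shifted.exp().sum(dim=dim, keepdim=True).log()

def log_softmax_naive(x: torch.Tensor, dim: int) -> torch.Tensor:
    """The straightforward formula log(exp(x) / sum(exp(x))) (hypothetical helper)."""
    return torch.log(x.exp() / x.exp().sum(dim=dim, keepdim=True))

x = torch.tensor([[1.0, 2.0, 3.0],
                  [1000.0, 1001.0, 1002.0]])

naive = log_softmax_naive(x, 1)   # second row: exp(1000.) overflows float32 to inf, inf/inf -> nan
stable = log_softmax_lse(x, 1)    # both rows give the same values
builtin = torch.log_softmax(x, 1)
print(naive)
print(stable)
print(builtin)
```

Because the two rows differ only by a constant offset, their log-softmax values are identical; the shifted version reproduces that, while the naive one breaks down as soon as the logits are large.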

Best.

K. Frank


Thanks! This is a clever approach! I didn’t know take_along_dim existed before. And thanks a lot for pointing out the log_softmax trick!