# Cross entropy for multi-dimensional implementation

Hi, I would like to understand the cross entropy loss for multi-dimensional input by implementing it myself. Currently I am able to get a matching result by iterating with `np.ndindex`:

```python
import numpy as np
import torch
from torch.nn import CrossEntropyLoss

K = 10
X = torch.randn(32, K, 20, 30)
t = torch.randint(K, (32, 20, 30)).long()
w = torch.randn(K).abs()

CrossEntropyLoss(weight=w, reduction="none")(X, t)

# this is the same as

XX = X.movedim(1, -1)
Ls = torch.zeros_like(t, dtype=torch.float)
for ind in np.ndindex(32, 20, 30):
    label = t[ind]
    loss = -w[label] * torch.log(XX[ind][label].exp() / XX[ind].exp().sum())
    Ls[ind] = loss
Ls
```

But this involves looping, which I assume would be slower than a direct vectorized computation. Therefore I am wondering if there are some tricks I can use to implement it in plain torch. Thanks!

Any ideas? I am looking for a cleaner way to replace the loop. Thanks!

Hi Chris!

You can get a tensor of weights corresponding to the class labels
in `t` by indexing into `w`, and you can get a tensor of (negative)
log-probabilities by calling `.take_along_dim()` with `t` as the
argument:

```python
>>> import torch
>>> torch.__version__
'1.9.0'
>>> _ = torch.manual_seed (2021)
>>> K = 10
>>> X = torch.randn(32, K, 20, 30)
>>> t = torch.randint(K, (32, 20, 30)).long()
>>> w = torch.randn(K).abs()
>>> xeA = torch.nn.CrossEntropyLoss (weight = w, reduction = 'none') (X, t)
>>> xeB = w[t] * -X.log_softmax (1).take_along_dim (t.unsqueeze (1), 1).squeeze (1)
>>> torch.allclose (xeA, xeB)
True
>>> torch.equal (xeA, xeB)
True
```
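In case you are on a pytorch version that predates `take_along_dim` (it was added in 1.9), `torch.gather` with the same index tensor selects the same per-class entries; a sketch of the equivalent computation:

```python
import torch

torch.manual_seed(2021)
K = 10
X = torch.randn(32, K, 20, 30)
t = torch.randint(K, (32, 20, 30)).long()
w = torch.randn(K).abs()

xeA = torch.nn.CrossEntropyLoss(weight=w, reduction='none')(X, t)

# gather() along the class dimension with index shape (32, 1, 20, 30)
# picks out each pixel's target-class log-probability, just like
# take_along_dim() does
xeC = w[t] * -X.log_softmax(1).gather(1, t.unsqueeze(1)).squeeze(1)

print(torch.allclose(xeA, xeC))
```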

As an aside:

you are computing log-softmax() “by hand” in the straightforward way,
but it is numerically more stable to use the log-sum-exp trick, either
by implementing it explicitly or by using pytorch’s `log_softmax()`
function (as I have done above).
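For reference, the log-sum-exp trick subtracts the per-row maximum before exponentiating, which leaves the result unchanged (because log(sum exp(x_i - m)) + m == log(sum exp(x_i))) but keeps the exponentials from overflowing. A minimal sketch with made-up inputs:

```python
import torch

x = torch.tensor([1000.0, 1001.0, 1002.0])

# naive log-softmax: exp(1000.) overflows to inf in float32,
# so the result is not finite
naive = torch.log(x.exp() / x.exp().sum())

# log-sum-exp trick: subtract the max first; all exponents are <= 0,
# so nothing overflows
m = x.max()
stable = (x - m) - (x - m).exp().sum().log()

print(torch.isfinite(naive).all())                # naive version breaks
print(torch.allclose(stable, x.log_softmax(0)))   # trick matches pytorch
```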

Best.

K. Frank


Thanks! This is a clever way! I didn’t know there was a `take_along_dim` function beforehand. Thanks a lot for pointing out the log_softmax trick!