# Cross entropy for multi-dimensional implementation

Hi, I would like to understand the cross entropy loss for multi-dimensional input by implementing it myself. Currently I am able to get a matching result by iterating with `np.ndindex`:

```python
import numpy as np
import torch
from torch.nn import CrossEntropyLoss

K = 10
X = torch.randn(32, K, 20, 30)
t = torch.randint(K, (32, 20, 30)).long()
w = torch.randn(K).abs()

CrossEntropyLoss(weight=w, reduction="none")(X, t)

# this is the same as

XX = X.movedim(1, -1)
Ls = torch.zeros_like(t, dtype=torch.float)
for ind in np.ndindex(32, 20, 30):
    label = t[ind]
    loss = -w[label] * torch.log(XX[ind][label].exp() / XX[ind].exp().sum())
    Ls[ind] = loss
Ls
```

But this involves looping, which I assume would be slower than a direct vectorized computation. Therefore I am wondering if there are some tricks I can use to implement it in plain torch. Thanks!

Any ideas? I am looking for a cleaner way to replace the loop. Thanks!

Hi Chris!

You can get a tensor of weights corresponding to the class labels
in `t` by indexing into `w`, and you can get a tensor of (negative)
log-probabilities by calling `.take_along_dim()` with `t` as the
argument:

```python
>>> import torch
>>> torch.__version__
'1.9.0'
>>> _ = torch.manual_seed (2021)
>>> K = 10
>>> X = torch.randn(32, K, 20, 30)
>>> t = torch.randint(K, (32, 20, 30)).long()
>>> w = torch.randn(K).abs()
>>> xeA = torch.nn.CrossEntropyLoss (weight = w, reduction = 'none') (X, t)
>>> xeB = w[t] * -X.log_softmax (1).take_along_dim (t.unsqueeze (1), 1).squeeze (1)
>>> torch.allclose (xeA, xeB)
True
>>> torch.equal (xeA, xeB)
True
```
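In case you are on a pytorch version that predates `take_along_dim` (it was added in 1.9), `torch.gather` with the same index tensor selects the same per-class entries; a sketch of the equivalent computation:

```python
import torch

torch.manual_seed(2021)
K = 10
X = torch.randn(32, K, 20, 30)
t = torch.randint(K, (32, 20, 30)).long()
w = torch.randn(K).abs()

xeA = torch.nn.CrossEntropyLoss(weight=w, reduction='none')(X, t)

# gather() along the class dimension with index shape (32, 1, 20, 30)
# picks out each pixel's target-class log-probability, just like
# take_along_dim() does
xeC = w[t] * -X.log_softmax(1).gather(1, t.unsqueeze(1)).squeeze(1)

print(torch.allclose(xeA, xeC))
```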

As an aside:

you are computing log-softmax() “by hand” in the straightforward way,
but it is numerically more stable to use the log-sum-exp trick, either
by implementing it explicitly or by using pytorch’s `log_softmax()`
function (as I have done above).
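For reference, the log-sum-exp trick subtracts the per-row maximum before exponentiating, which leaves the result unchanged (because log(sum exp(x_i - m)) + m == log(sum exp(x_i))) but keeps the exponentials from overflowing. A minimal sketch with made-up inputs:

```python
import torch

x = torch.tensor([1000.0, 1001.0, 1002.0])

# naive log-softmax: exp(1000.) overflows to inf in float32,
# so the result is not finite
naive = torch.log(x.exp() / x.exp().sum())

# log-sum-exp trick: subtract the max first; all exponents are <= 0,
# so nothing overflows
m = x.max()
stable = (x - m) - (x - m).exp().sum().log()

print(torch.isfinite(naive).all())                # naive version breaks
print(torch.allclose(stable, x.log_softmax(0)))   # trick matches pytorch
```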

Best.

K. Frank


Thanks! This is a clever way! I didn’t know there was a `take_along_dim` function beforehand. Thanks a lot for pointing out the log_softmax trick!