I want to compute the sum of the cross entropy over all classes for each prediction, where the input is a batch (of size n) and the output is a batch (of size n).

The simplest way is a for loop over the 1000 classes:

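    import torch

    def sum_of_CE_lost(input):
        CE = torch.nn.CrossEntropyLoss()  # default reduction='mean' over the batch
        L = 0
        for x in range(1000):  # one pass per class
            # every sample in the batch gets target class x
            target = torch.full((input.shape[0],), x, dtype=torch.int64, device=input.device)
            L = L + CE(input, target)
        loss = L.sum()
        return loss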

However, it is very slow. What is a better way? How can we parallelize it for the GPU (CUDA)?

`input` is the prediction of imagenet_resnet50 on a batch, i.e., the shape of input is [batch_size, number_of_classes].

Yes, this is straightforward, based on the definition of cross_entropy.

Let input be a tensor of shape (nBatch, nClass). Then (mentally
ignoring the minor errors in your code such as using the class version
of CrossEntropyLoss as if it were a function and c not being a tensor)
the following will match the result of the code you posted:
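A minimal sketch of such a loop-free expression, assuming `input` holds raw logits of shape (nBatch, nClass), the default reduction='mean', and no class weights. The function names `sum_of_ce_loop` and `sum_of_ce_vectorized` are hypothetical; the loop version is kept only as a reference to check against on small random data (nClass = 5).

```python
import torch
import torch.nn.functional as F

def sum_of_ce_loop(logits):
    # Reference: one cross_entropy call per class c, with every sample labeled c
    n_batch, n_class = logits.shape
    total = 0.0
    for c in range(n_class):
        target = torch.full((n_batch,), c, dtype=torch.long, device=logits.device)
        total = total + F.cross_entropy(logits, target)  # mean over the batch
    return total

def sum_of_ce_vectorized(logits):
    # Sum -log_softmax over every (sample, class) entry, then divide by the
    # batch size to reproduce cross_entropy's default reduction='mean'
    return -F.log_softmax(logits, dim=1).sum() / logits.shape[0]

torch.manual_seed(0)
logits = torch.randn(4, 5)  # small random example: nBatch = 4, nClass = 5
print(sum_of_ce_loop(logits).item(), sum_of_ce_vectorized(logits).item())
print(torch.allclose(sum_of_ce_loop(logits), sum_of_ce_vectorized(logits), atol=1e-5))
```

Because `log_softmax` is evaluated once over the whole (nBatch, nClass) tensor, the per-class Python loop collapses into one `log_softmax` call and one `sum`.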

(We divide by input.shape[0] because cross_entropy() takes, by
default, the mean across the batch dimension.)
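Written out, assuming the default reduction='mean' and no class weights, with z_i denoting row i of input, n = input.shape[0], and C the number of classes, the division follows directly from the definition:

```latex
\mathrm{CE}(z, c) = \frac{1}{n}\sum_{i=1}^{n} -\log \operatorname{softmax}(z_i)_c

\sum_{c=1}^{C} \mathrm{CE}(z, c)
  = -\frac{1}{n}\sum_{i=1}^{n}\sum_{c=1}^{C} \log \operatorname{softmax}(z_i)_c
  = \frac{1}{n}\sum_{i=1}^{n}\Bigl( C \log \sum_{k=1}^{C} e^{z_{ik}} - \sum_{c=1}^{C} z_{ic} \Bigr)
```

The last form uses log softmax(z_i)_c = z_ic - log sum_k exp(z_ik), so the whole quantity needs only one logsumexp per row.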

Because this expression uses pytorch tensor functions, you will
automatically get the benefit of pytorch’s gpu support (if you move
your tensors to the gpu), as well as autograd, if you care.
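A minimal sketch of the GPU and autograd point, assuming random logits as a hypothetical stand-in for the ResNet-50 output (nBatch = 8, nClass = 1000); the device falls back to the CPU when CUDA is unavailable.

```python
import torch
import torch.nn.functional as F

# Use the GPU when one is available; the same code runs unchanged on the CPU
device = "cuda" if torch.cuda.is_available() else "cpu"

# Hypothetical stand-in for the ResNet-50 logits: nBatch = 8, nClass = 1000
logits = torch.randn(8, 1000, device=device, requires_grad=True)

# Loop-free sum over all classes of the batch-mean cross entropy
loss = -F.log_softmax(logits, dim=1).sum() / logits.shape[0]

# autograd works through the expression like any other loss
loss.backward()
print(loss.item(), logits.grad.shape)
```

Each row of the gradient sums to zero: the derivative of the loss with respect to z_ij is (C softmax(z_i)_j - 1) / n, and the softmax entries of a row sum to one.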

(As an aside, though, it’s not really clear to me that your calculation
makes a lot of sense. But you can compute it with pytorch tensors
on the gpu without explicit loops.)

The input is not of shape (nBatch, nClass). The input has shape (nBatch, mLogit), i.e., each row has 1000 elements, one for each class.

Whether we call this number, 1000, nClass or mLogit is immaterial.
I maintain that my expression reproduces the result of your (corrected)
pseudo-code. Have you tried it?

If you come to the contrary conclusion, please post a complete,
runnable sample script that illustrates your contrary result. (You
can generate random sample data for your input tensor, but please
trim down the number 1000 to something more manageable like
5 or 10 to make everyone’s life easier.)

Also, to get our vocabulary straight, what do you mean by “class”
when you say “sum over all classes?” How many classes do you
have in your example? Should I understand the value of the quantity
“C” in the image of the equation you posted to be your number of
classes (and equal to 1000 in your example)?