Why does cross entropy loss have a separate input format for pixel-level classification?

I was working with cross entropy loss on a sequence labeling problem. Just like pixel-level classification, sequence labeling predicts a class for every token in the sequence.

I read in the documentation of cross entropy loss that there is a separate input format for such fine-grained classification.

Let us stick to class probabilities as the ground truth, NOT class indices.

The standard input/output format is:

N x C

whereas the input/output format for multidimensional data, i.e. fine-grained classification, is:

N x C x d_1 x d_2 x … x d_k

where N is the batch size (number of training examples), C is the number of classes, and d_1 … d_k are the dimensions of the input (for an image k = 2, for a text sequence k = 1).
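For reference, here is a minimal sketch of the two documented shapes side by side (random tensors just to show the shapes; probability targets require PyTorch >= 1.10):

```python
import torch

loss = torch.nn.CrossEntropyLoss()

# standard format: N x C logits with N x C probability targets
out_2d = loss(torch.randn(4, 5), torch.rand(4, 5).softmax(dim=1))

# multidimensional format: N x C x d1 x d2 logits with matching targets
out_4d = loss(torch.randn(4, 5, 7, 7), torch.rand(4, 5, 7, 7).softmax(dim=1))

print(out_2d, out_4d)  # both reduce to scalar losses
```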

But I am wondering: why is it even needed? We can recast our problem into the standard format. How? Take pixel-level classification: simply treat each pixel, rather than each image, as a training example. That's it.

I have tested this approach, and it works.

Here is the working code.

Let us take number of classes (C) = 3, image size (d1 x d2) = 3x3, and number of images (N) = 2.

Let us try with the separate input format.

import torch

loss = torch.nn.CrossEntropyLoss()
predictions = torch.tensor([[[[1., 10., 1.], [1., 1., 10.], [1., 1., 1.]],
                             [[10., 1., 10.], [10., 1., 1.], [1., 10., 1.]],
                             [[1., 1., 1.], [1., 10., 1.], [10., 1., 10.]]],
                            [[[1., 1., 1.], [1., 10., 1.], [10., 1., 1.]],
                             [[10., 10., 1.], [1., 1., 1.], [1., 1., 10.]],
                             [[1., 1., 10.], [10., 1., 10.], [1., 10., 1.]]]],
                           requires_grad=True)  # N x C x d1 x d2 = 2 x 3 x 3 x 3
target = torch.tensor([[[[0., 1., 0.], [0., 0., 1.], [0., 0., 0.]],
                        [[1., 0., 1.], [1., 0., 0.], [0., 1., 0.]],
                        [[0., 0., 0.], [0., 1., 0.], [1., 0., 1.]]],
                       [[[0., 0., 0.], [0., 1., 0.], [1., 0., 0.]],
                        [[1., 1., 0.], [0., 0., 0.], [0., 0., 1.]],
                        [[0., 0., 1.], [1., 0., 1.], [0., 1., 0.]]]])  # N x C x d1 x d2 = 2 x 3 x 3 x 3
output = loss(predictions, target)  # probability targets require PyTorch >= 1.10
print(output)

Now, let us test the same thing by permuting the axes, i.e. treating each pixel as an independent training example.

import torch

loss = torch.nn.CrossEntropyLoss()
predictions = torch.tensor([[[[1., 10., 1.], [1., 1., 10.], [1., 1., 1.]],
                             [[10., 1., 10.], [10., 1., 1.], [1., 10., 1.]],
                             [[1., 1., 1.], [1., 10., 1.], [10., 1., 10.]]],
                            [[[1., 1., 1.], [1., 10., 1.], [10., 1., 1.]],
                             [[10., 10., 1.], [1., 1., 1.], [1., 1., 10.]],
                             [[1., 1., 10.], [10., 1., 10.], [1., 10., 1.]]]],
                           requires_grad=True)  # N x C x d1 x d2 = 2 x 3 x 3 x 3

# move the class dim last, then flatten so that every pixel becomes its own row
predictions = predictions.permute(0, 2, 3, 1).reshape(-1, 3)  # (N*d1*d2, C)

target = torch.tensor([[[[0., 1., 0.], [0., 0., 1.], [0., 0., 0.]],
                        [[1., 0., 1.], [1., 0., 0.], [0., 1., 0.]],
                        [[0., 0., 0.], [0., 1., 0.], [1., 0., 1.]]],
                       [[[0., 0., 0.], [0., 1., 0.], [1., 0., 0.]],
                        [[1., 1., 0.], [0., 0., 0.], [0., 0., 1.]],
                        [[0., 0., 1.], [1., 0., 1.], [0., 1., 0.]]]])  # N x C x d1 x d2 = 2 x 3 x 3 x 3

target = target.permute(0, 2, 3, 1).reshape(-1, 3)  # (N*d1*d2, C)
output = loss(predictions, target)
print(output)

Both give the same value. The same trick works with class indices. So why not ask users to permute their tensors themselves before passing them to the loss function?
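For completeness, here is the same equivalence checked with class indices instead of probabilities (random tensors for brevity; note that with class indices the target has no C dimension):

```python
import torch

loss = torch.nn.CrossEntropyLoss()
torch.manual_seed(0)
N, C, H, W = 2, 3, 3, 3
predictions = torch.randn(N, C, H, W)
target = torch.randint(0, C, (N, H, W))  # class indices: N x d1 x d2, no C dim

multi = loss(predictions, target)                            # multidimensional format
flat = loss(predictions.permute(0, 2, 3, 1).reshape(-1, C),  # one row per pixel
            target.reshape(-1))
print(multi, flat)  # same value
```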

Or is it just to make life easier, with the same permutation happening under the hood? I couldn't follow the source code for cross entropy (it refers to some torch._C.nn.cross_entropy), which is why I am not sure.

The main reason would be convenience and avoiding code duplication, which could add errors to user code. The same applies to e.g. weighted losses: you could simply multiply the unreduced loss by a weight tensor and reduce it afterwards. However, would you remember that the reduction should divide by the applied weights, or would you just call mean() on the weighted loss?
A similar approach is also possible for e.g. grouped convolutions, where one can permute the channel blocks into the batch dim, but this is also too error-prone, and it makes sense to provide these methods to the user.

So basically convenience for the user, isn't it?

No, not only convenience, but also avoiding code duplication and thus errors.