Loss function for one-hot encoding

I am trying to train an MNIST classifier with one-hot encoded targets, but cross entropy loss won't accept the one-hot vectors. What loss function should I be using?

What if you take the argmax of the one-hot vector and pass that to cross entropy loss instead?
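Something like this minimal sketch, with random tensors standing in for your model's outputs and targets:

```python
import torch
import torch.nn as nn

criterion = nn.CrossEntropyLoss()

# made-up batch just for illustration: random logits and random one-hot targets
logits  = torch.randn(4, 10)                   # model outputs: (batch, num_classes)
labels  = torch.randint(0, 10, (4,))           # hypothetical ground-truth classes
one_hot = nn.functional.one_hot(labels, num_classes=10).float()

targets = one_hot.argmax(dim=1)      # recover the class index from each one-hot row
loss = criterion(logits, targets)    # runs: CrossEntropyLoss wants indices, not vectors
```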

But wouldn’t the argmax of every one-hot-encoded vector be 1?

No. argmax gives the index of the maximum element, not its value. The max value of a one-hot vector is always 1, but the index of that 1 is exactly the class label.
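For example, with a made-up one-hot vector encoding class 3:

```python
import torch

v = torch.tensor([0., 0., 0., 1., 0.])  # one-hot vector for class 3
print(v.max())     # tensor(1.) -- the maximum value is always 1
print(v.argmax())  # tensor(3)  -- the index of that maximum: the class label
```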

Taking the argmax worked. Thanks!
But I was looking at the PyTorch CIFAR-10 tutorial, and its network has an output layer of width 10 while each target is just a scalar. How does that work?

The cross entropy loss function takes care of this internally: it expects the raw 10-dimensional outputs (logits) plus a scalar class index per sample.

Can you please explain how that happens, or point me to a resource?
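For what it's worth, the PyTorch docs describe CrossEntropyLoss as combining LogSoftmax and NLLLoss: the logits are normalized into log-probabilities, and the scalar target is used as an index to pick out the log-probability to penalize. A quick sketch of that equivalence, with made-up logits and targets:

```python
import torch
import torch.nn.functional as F

logits  = torch.randn(4, 10)           # fake model outputs for a batch of 4
targets = torch.tensor([3, 1, 0, 7])   # scalar class indices, one per sample

loss_ce = F.cross_entropy(logits, targets)

# the same computation in two explicit steps:
log_probs = F.log_softmax(logits, dim=1)    # logits -> log-probabilities
loss_nll  = F.nll_loss(log_probs, targets)  # mean of -log_probs[i, targets[i]]

print(torch.allclose(loss_ce, loss_nll))    # True
```

Indexing with the target is equivalent to a dot product with the one-hot vector, since every other entry is zero, which is why a scalar class index is all the loss needs.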