Strange pair of errors using CrossEntropyLoss

I’m running into a very strange pair of errors when trying to calculate the Cross Entropy Loss. First I got this error message:

D:\anaconda3\envs\downgrade\lib\site-packages\torch\nn\functional.py in cross_entropy(input, target, weight, size_average, ignore_index, reduce, reduction, label_smoothing)
   2844     if size_average is not None or reduce is not None:
   2845         reduction = _Reduction.legacy_get_string(size_average, reduce)
-> 2846     return torch._C._nn.cross_entropy_loss(input, target, weight, _Reduction.get_enum(reduction), ignore_index, label_smoothing)
   2847 
   2848 

RuntimeError: "log_softmax_lastdim_kernel_impl" not implemented for 'Long'
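The kernel named in the first message, log-softmax, is only implemented for floating-point dtypes. Here is a minimal sketch that reproduces this, assuming a recent PyTorch build on CPU and a hypothetical tensor of integer scores. The exact error text varies between versions.

```python
import torch
import torch.nn.functional as F

# Hypothetical scores stored as integers (torch.int64, which PyTorch calls "Long").
scores = torch.tensor([[2, 0, 1], [0, 3, 1]])

long_error = None
try:
    F.log_softmax(scores, dim=1)
except RuntimeError as err:
    # NotImplementedError subclasses RuntimeError, so this also covers builds
    # that may report the missing kernel that way.
    long_error = str(err)
print("Long input:", long_error)

# The same values cast to float32 go through without complaint.
log_probs = F.log_softmax(scores.float(), dim=1)
print("float input:", log_probs.dtype)
```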

Now, it seems obvious that the softmax function isn’t going to work on a Long; it should be a float. So I converted the values to float and got this error message:

D:\anaconda3\envs\downgrade\lib\site-packages\torch\nn\functional.py in cross_entropy(input, target, weight, size_average, ignore_index, reduce, reduction, label_smoothing)
   2844     if size_average is not None or reduce is not None:
   2845         reduction = _Reduction.legacy_get_string(size_average, reduce)
-> 2846     return torch._C._nn.cross_entropy_loss(input, target, weight, _Reduction.get_enum(reduction), ignore_index, label_smoothing)
   2847 
   2848 

RuntimeError: expected scalar type Long but found Float

“Expected Long but found Float”? So should it be a Long or a Float? This traceback is pretty deep in the library, but it points to the exact same line of code both times. How can two errors on the same line of code require the values to be both a Long and a Float? Why can’t it just convert the values implicitly?
Nathan
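The single call on that line receives two arguments with different roles. The input goes through log-softmax, while the target is read as integer class indices. Here is a minimal sketch that tries every Long/Float combination, assuming a recent PyTorch on CPU and a hypothetical batch of 4 samples and 3 classes.

```python
import torch
import torch.nn.functional as F

# Hypothetical batch: integer scores (N=4, C=3) and one class index per sample.
logits = torch.tensor([[2, 0, 1], [0, 3, 1], [1, 1, 4], [5, 0, 0]])
labels = torch.tensor([0, 1, 2, 0])

outcomes = {}
for in_dtype in (torch.long, torch.float):
    for tgt_dtype in (torch.long, torch.float):
        try:
            F.cross_entropy(logits.to(in_dtype), labels.to(tgt_dtype))
            outcomes[(in_dtype, tgt_dtype)] = "ok"
        except RuntimeError as err:
            # Keep only the first line of the message for a compact summary.
            outcomes[(in_dtype, tgt_dtype)] = str(err).splitlines()[0]

for (in_dtype, tgt_dtype), result in outcomes.items():
    print(f"input={in_dtype}, target={tgt_dtype}: {result}")
```

Only a float input with a Long target succeeds here. Recent versions also accept a float target with the same shape as the input and treat it as per-class probabilities, but that is a different use case from class indices.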

Turns out they should both be Long, but the two trace logs I showed are pointing to two (very similar) lines of code.

Can you please explain how you resolved this issue? Won’t we still get the first RuntimeError if we leave it as Long?

Hello,

In my particular case, the input (the raw class scores) should be float, while the targets (the class indices) should be converted to Long.

Nathan
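Here is a minimal sketch of that fix, assuming hypothetical integer scores and labels that ended up stored as floats. It casts the input with .float() and the target with .long(). It also checks that, for class-index targets, the loss equals log_softmax over the input followed by nll_loss. This is why only the input has to be floating point.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

# Hypothetical raw data: integer scores, and labels that ended up stored as floats.
raw_scores = torch.tensor([[2, 0, 1], [0, 3, 1], [1, 1, 4]])
raw_labels = torch.tensor([0.0, 1.0, 2.0])

inputs = raw_scores.float()  # input: floating point, shape (N, C)
targets = raw_labels.long()  # target: int64 class indices, shape (N,)

criterion = nn.CrossEntropyLoss()
loss = criterion(inputs, targets)

# With class-index targets, the loss equals log_softmax over the input followed
# by nll_loss, which only uses each target value to pick one entry per row.
manual = F.nll_loss(F.log_softmax(inputs, dim=1), targets)
print(loss.item(), manual.item())
```

Casting with .long() truncates toward zero. The conversion therefore only makes sense when the float labels already hold whole numbers.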