Hi, I am currently training an NN to classify inputs into 199 categories using PyTorch's cross_entropy loss. If I try using the per-class probabilities as the targets and cast the targets to float:
def cross_entropy_loss(sender_input, _message, _receiver_input, receiver_output, _labels, _aux_input=None):
    _labels = F.one_hot(_labels.long(), receiver_output.shape[-1])
    loss = F.cross_entropy(receiver_output.squeeze(), _labels.float(), reduction='none', label_smoothing=0.1)
    return loss, {}
I get the output for the first epoch:
{"loss": 5.204253673553467, "baseline": 5.205080986022949, "sender_entropy": 0.001109135104343295, "receiver_entropy": 0.0, "mode": "train", "epoch": 1}
and then the following error:
<ipython-input-91-936be60121ba> in cross_entropy_loss(sender_input, _message, _receiver_input, receiver_output, _labels, _aux_input)
5 def cross_entropy_loss(sender_input, _message, _receiver_input, receiver_output, _labels, _aux_input=None):
6 _labels = F.one_hot(_labels.long(),receiver_output.shape[-1])
----> 7 loss = F.cross_entropy(receiver_output.squeeze(), _labels.float(), reduction='none',label_smoothing=0.1)
8 return loss, {}
~\Anaconda3\envs\egg36\lib\site-packages\torch\nn\functional.py in cross_entropy(input, target, weight, size_average, ignore_index, reduce, reduction, label_smoothing)
2844 if size_average is not None or reduce is not None:
2845 reduction = _Reduction.legacy_get_string(size_average, reduce)
-> 2846 return torch._C._nn.cross_entropy_loss(input, target, weight, _Reduction.get_enum(reduction), ignore_index, label_smoothing)
2847
2848
RuntimeError: expected scalar type Long but found Float
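For reference, my reading of the current documentation is that a call like this, with floating-point class-probability targets, should be accepted. This is the standalone pattern I am trying to follow, with dummy tensors standing in for receiver_output and _labels (only the batch size of 32 and the 199 classes are taken from my setup, everything else is made up):

import torch
import torch.nn.functional as F

logits = torch.randn(32, 199)                                  # stand-in for receiver_output.squeeze()
labels = torch.randint(0, 199, (32,))                          # stand-in for _labels
target_probs = F.one_hot(labels, num_classes=199).float()      # class probabilities, dtype float
loss = F.cross_entropy(logits, target_probs, reduction='none', label_smoothing=0.1)
print(loss.shape)                                              # torch.Size([32]), one loss per sample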
If, on the other hand, I cast the target to long:
def cross_entropy_loss(sender_input, _message, _receiver_input, receiver_output, _labels, _aux_input=None):
    _labels = F.one_hot(_labels.long(), receiver_output.shape[-1])
    loss = F.cross_entropy(receiver_output.squeeze(), _labels.long(), reduction='none', label_smoothing=0.1)
    return loss, {}
I immediately get the following error:
<ipython-input-94-e319bc27aeaa> in cross_entropy_loss(sender_input, _message, _receiver_input, receiver_output, _labels, _aux_input)
5 def cross_entropy_loss(sender_input, _message, _receiver_input, receiver_output, _labels, _aux_input=None):
6 _labels = F.one_hot(_labels.long(),receiver_output.shape[-1])
----> 7 loss = F.cross_entropy(receiver_output.squeeze(), _labels.long(), reduction='none',label_smoothing=0.1)
8 return loss, {}
~\Anaconda3\envs\egg36\lib\site-packages\torch\nn\functional.py in cross_entropy(input, target, weight, size_average, ignore_index, reduce, reduction, label_smoothing)
2844 if size_average is not None or reduce is not None:
2845 reduction = _Reduction.legacy_get_string(size_average, reduce)
-> 2846 return torch._C._nn.cross_entropy_loss(input, target, weight, _Reduction.get_enum(reduction), ignore_index, label_smoothing)
2847
2848
RuntimeError: Expected floating point type for target with class probabilities, got Long
As you can see, in both cases the error is triggered by the loss function I defined (line 7 of the snippet) and, internally, at line 2846 of torch.nn.functional, and the two errors ask for contradictory target formats.
If I change the loss function to work with index targets instead of probabilities:
def cross_entropy_loss(sender_input, _message, _receiver_input, receiver_output, _labels, _aux_input=None):
    # _labels = F.one_hot(_labels.long(), receiver_output.shape[-1])
    loss = F.cross_entropy(receiver_output.squeeze(), _labels.long(), reduction='none', label_smoothing=0.1)
    return loss, {}
I also get the output for the first batch, but after that I get the following error (raised from the same line as the others):
~\Anaconda3\envs\egg36\lib\site-packages\torch\nn\functional.py in cross_entropy(input, target, weight, size_average, ignore_index, reduce, reduction, label_smoothing)
2844 if size_average is not None or reduce is not None:
2845 reduction = _Reduction.legacy_get_string(size_average, reduce)
-> 2846 return torch._C._nn.cross_entropy_loss(input, target, weight, _Reduction.get_enum(reduction), ignore_index, label_smoothing)
2847
2848
RuntimeError: Expected target size [32, 199], got [32]
According to the documentation, this should not be a problem for index targets.
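For comparison, this is the index-target pattern I read the documentation as describing, again with dummy tensors of the same shapes (assumptions, not my actual model outputs):

import torch
import torch.nn.functional as F

logits = torch.randn(32, 199)                                  # stand-in for receiver_output.squeeze()
target_idx = torch.randint(0, 199, (32,))                      # class indices, shape (batch,), dtype long
loss = F.cross_entropy(logits, target_idx, reduction='none', label_smoothing=0.1)
print(loss.shape)                                              # torch.Size([32])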
I am completely at a loss as to what is happening here. I tried the exact same loss with reduction='mean' and it worked fine, so maybe the issue is there. Still, I need this part to use reduction='none', so I have to fix it.