I tried to run the CrossEntropyLoss example shown in the picture, and it raised an error.
The code is as follows:
import torch
import torch.nn as nn

# Example of target with class indices
loss = nn.CrossEntropyLoss()
input = torch.randn(3, 5, requires_grad=True)
target = torch.empty(3, dtype=torch.long).random_(5)
output = loss(input, target)
output.backward()
# Example of target with class probabilities
input = torch.randn(3, 5, requires_grad=True)
target = torch.randn(3, 5).softmax(dim=1)
output = loss(input, target)
output.backward()
The error is as follows:
---------------------------------------------------------------------------
RuntimeError Traceback (most recent call last)
/tmp/ipykernel_32/2605795972.py in <module>
9 input = torch.randn(3, 5, requires_grad=True)
10 target = torch.randn(3, 5).softmax(dim=1)
---> 11 output = loss(input, target)
12 output.backward()
/opt/conda/lib/python3.7/site-packages/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs)
1049 if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
1050 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1051 return forward_call(*input, **kwargs)
1052 # Do not call functions when jit is used
1053 full_backward_hooks, non_full_backward_hooks = [], []
/opt/conda/lib/python3.7/site-packages/torch/nn/modules/loss.py in forward(self, input, target)
1119 def forward(self, input: Tensor, target: Tensor) -> Tensor:
1120 return F.cross_entropy(input, target, weight=self.weight,
-> 1121 ignore_index=self.ignore_index, reduction=self.reduction)
1122
1123
/opt/conda/lib/python3.7/site-packages/torch/nn/functional.py in cross_entropy(input, target, weight, size_average, ignore_index, reduce, reduction)
2822 if size_average is not None or reduce is not None:
2823 reduction = _Reduction.legacy_get_string(size_average, reduce)
-> 2824 return torch._C._nn.cross_entropy_loss(input, target, weight, _Reduction.get_enum(reduction), ignore_index)
2825
2826
RuntimeError: 1D target tensor expected, multi-target not supported
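The traceback suggests that the installed PyTorch only accepts a 1D tensor of class indices as the target. As far as I know, support for class-probability targets in CrossEntropyLoss arrived in PyTorch 1.10, so documentation for a newer release can show an example that fails on an older install. A minimal sketch of a workaround for older versions, assuming the goal is the default mean-reduced loss without `weight` or label smoothing, computes soft-target cross entropy by hand (`soft_cross_entropy` is a hypothetical helper, not a PyTorch API):

```python
import torch
import torch.nn.functional as F

def soft_cross_entropy(logits, target_probs):
    # Hypothetical helper: per sample, -sum_c p_c * log_softmax(logits)_c,
    # then averaged over the batch (like reduction="mean").
    log_probs = F.log_softmax(logits, dim=1)
    return -(target_probs * log_probs).sum(dim=1).mean()

torch.manual_seed(0)
logits = torch.randn(3, 5, requires_grad=True)   # [N, C] = [3, 5]
target = torch.randn(3, 5).softmax(dim=1)        # each row sums to 1
output = soft_cross_entropy(logits, target)
output.backward()                                # gradients flow into logits
print(output.item())
```

With one-hot targets this reduces to the ordinary class-index loss, which is a quick sanity check.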
Questions:
1) The error is raised at line 11. Does this mean the example itself is wrong?
2) In a classification problem, does the label tensor passed to CrossEntropyLoss() need to have shape [n, 1]?