RuntimeError: 0D or 1D target tensor expected, multi-target not supported in PyTorch

I would like to do binary classification with softmax in PyTorch. Even though I set the number of outputs to 2 and use nn.CrossEntropyLoss(), I am getting the following error:

RuntimeError: 0D or 1D target tensor expected, multi-target not supported

train_loss = []

for epoch in range(epochs):
    y_pred = model(x_train)
    loss = loss_func(y_pred, y_tensor)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    train_loss.append(loss.item())

The shapes of y_pred and y_tensor are torch.Size([3000, 2]) and torch.Size([3000, 1]), respectively. How can this issue be solved?

Thanks in advance.

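I think your ground-truth tensor y_tensor should have shape [3000], not [3000, 1]. When the input has shape [N, C], nn.CrossEntropyLoss expects the target to hold one class index per sample, i.e. shape [N]: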

>>> pred = torch.randn(10,2)
>>> gt = torch.randint(0,2,(10,))
>>> nn.CrossEntropyLoss()(pred,gt)
tensor(0.8060)
>>> gt = torch.randint(0,2,(10,1))
>>> nn.CrossEntropyLoss()(pred,gt)
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "D:\anaconda3\envs\taichi\lib\site-packages\torch\nn\modules\module.py", line 1102, in _call_impl
    return forward_call(*input, **kwargs)
  File "D:\anaconda3\envs\taichi\lib\site-packages\torch\nn\modules\loss.py", line 1152, in forward
    label_smoothing=self.label_smoothing)
  File "D:\anaconda3\envs\taichi\lib\site-packages\torch\nn\functional.py", line 2846, in cross_entropy
    return torch._C._nn.cross_entropy_loss(input, target, weight, _Reduction.get_enum(reduction), ignore_index, label_smoothing)
RuntimeError: 0D or 1D target tensor expected, multi-target not supported
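A minimal sketch of the fix, assuming y_tensor already holds the labels 0 and 1 as a [3000, 1] column (the tensors below are random stand-ins for the ones in the question): squeeze away the trailing dimension and cast to int64 (long), since class-index targets must be integers.

```python
import torch
import torch.nn as nn

# Hypothetical stand-ins for the question's tensors:
# 3000 samples, 2 output logits, labels stored as a [3000, 1] float column.
y_pred = torch.randn(3000, 2)
y_tensor = torch.randint(0, 2, (3000, 1)).float()

# Drop the trailing singleton dimension and cast to int64 class indices.
target = y_tensor.squeeze(1).long()
print(target.shape, target.dtype)  # torch.Size([3000]) torch.int64

loss = nn.CrossEntropyLoss()(y_pred, target)
print(loss.item())
```

y_tensor.view(-1) would work equally well for the reshaping; squeeze(1) is used here because it only removes the dimension that is known to be of size 1.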

If I make that change to y_tensor, I get the following error:

IndexError: Target 4 is out of bounds.

I have two classes: one is labeled 0 and the other 4. Do I have to change the label 4 to 1, or is there a better way to deal with this error?

Yes, please change 4 to 1. The last dimension of your prediction output is 2, so nn.CrossEntropyLoss expects target class indices in the range [0, 1].
You can do this with y_tensor[y_tensor == 4] = 1 (or remap the labels in your dataloader).
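If the raw labels are arbitrary values (here 0 and 4), one general way to remap them to 0..C-1 is torch.unique with return_inverse=True. This is a sketch on a small hypothetical label tensor, not the only approach; the manual assignment above is equivalent for the two-class case.

```python
import torch
import torch.nn as nn

# Hypothetical labels using the values 0 and 4 from the question, shaped [N, 1].
y_tensor = torch.tensor([[0], [4], [4], [0], [4]])

# torch.unique returns the sorted distinct labels and, with return_inverse=True,
# the position of each element within that sorted list: 0 -> 0, 4 -> 1.
classes, target = torch.unique(y_tensor.squeeze(1), sorted=True, return_inverse=True)
print(classes)  # tensor([0, 4])
print(target)   # tensor([0, 1, 1, 0, 1])

# Random logits with one column per class, just to show the loss now accepts the target.
logits = torch.randn(target.numel(), len(classes))
loss = nn.CrossEntropyLoss()(logits, target)
```

To turn predictions back into the original label values, index the sorted class list with the predicted indices, e.g. classes[logits.argmax(dim=1)].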