CrossEntropy Issue

Alright, so I am trying to train a ViT using CrossEntropyLoss like this:

import torch
import torch.nn.functional as F

# Train function (assumes `device` is defined elsewhere, e.g. "cuda" or "cpu")
def train_combined(model, train_loader, criterion, optimizer, epochs=30):
  model.train()
  for epoch in range(epochs):
    running_loss = 0.0
    correct = 0
    total = 0
    for image, label in train_loader:
      image, label = image.to(device), label.to(device)
      optimizer.zero_grad()

      fwd_output = model(image)
      loss = criterion(fwd_output, label)

      loss.backward()
      optimizer.step()

      # Accumulate the loss so the epoch average printed below is meaningful
      running_loss += loss.item()

      # Calculate accuracy (index of the largest logit per sample)
      predicted = fwd_output.argmax(dim=1)
      total += label.size(0)
      correct += (predicted == label).sum().item()

    train_acc = 100 * correct / total
    print(f"Epoch [{epoch+1}/{epochs}], Loss: {running_loss / len(train_loader):.4f}, Accuracy: {train_acc:.2f}%")

When I do this, I get this message:

---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
<ipython-input-64-79153cd7a2a3> in <cell line: 1>()
----> 1 train_combined(new_vit_pathmnist, val_pathmnist_dataloader, loss_fn, optimizer_fn)

4 frames
/usr/local/lib/python3.10/dist-packages/torch/nn/functional.py in cross_entropy(input, target, weight, size_average, ignore_index, reduce, reduction, label_smoothing)
   3477     if size_average is not None or reduce is not None:
   3478         reduction = _Reduction.legacy_get_string(size_average, reduce)
-> 3479     return torch._C._nn.cross_entropy_loss(
   3480         input,
   3481         target,

RuntimeError: 0D or 1D target tensor expected, multi-target not supported

How can I fix this?

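Your target is supposed to contain class indices and have the same shape as the model output without the class dimension. For example, if the model output has the shape [batch_size, nb_classes], the target should have the shape [batch_size] and contain integer (torch.long) class indices in the range [0, nb_classes-1]. Based on the error message, it seems your target contains an additional dimension (e.g. [batch_size, 1]), which you could remove with label.squeeze(1) before passing it to the criterion.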
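Here is a minimal sketch that reproduces the error and the fix. It assumes hypothetical shapes (4 samples, 9 classes) and uses random data instead of the actual PathMNIST loader:

```python
import torch
import torch.nn as nn

# Hypothetical sizes for illustration only
batch_size, nb_classes = 4, 9
criterion = nn.CrossEntropyLoss()

# Model output (logits) with shape [batch_size, nb_classes]
output = torch.randn(batch_size, nb_classes)

# Target with an extra dimension: [batch_size, 1]
target_2d = torch.randint(0, nb_classes, (batch_size, 1))

raised = False
try:
    criterion(output, target_2d)
except RuntimeError as e:
    # Expected: "0D or 1D target tensor expected, multi-target not supported"
    raised = True
    print("Fails:", e)

# Remove the extra dimension so the target holds [batch_size] class indices
target = target_2d.squeeze(1).long()
loss = criterion(output, target)
print(target.shape, loss.item())
```

In the training loop, the same fix would go right after moving the data to the device, e.g. label = label.squeeze(1).long(), assuming your loader yields labels of shape [batch_size, 1]. This also keeps the accuracy check (predicted == label) elementwise; with a [batch_size, 1] label it would broadcast to [batch_size, batch_size] and inflate the correct count.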