Alright, so I am trying to train a ViT using CrossEntropyLoss like this:
import torch.nn.functional as F
# Train function
def train_combined(model, train_loader, criterion, optimizer, epochs=30):
model.train()
for epoch in range(epochs):
running_loss=0.0
correct=0
total=0
for image, label in train_loader:
image, label= image.to(device), label.to(device)
optimizer.zero_grad()
fwd_output=model(image)
loss=criterion(fwd_output, label)
loss.backward()
optimizer.step()
# Calculate accuracy
_, predicted = torch.max(fwd_output.data, 1)
total += label.size(0)
correct += (predicted == label).sum().item()
train_acc = 100 * correct / total
print(f"Epoch [{epoch+1}/{epochs}], Loss: {running_loss / len(train_loader):.4f}, Accuracy: {train_acc:.2f}%")
When I do this, I get this message:
---------------------------------------------------------------------------
RuntimeError Traceback (most recent call last)
<ipython-input-64-79153cd7a2a3> in <cell line: 1>()
----> 1 train_combined(new_vit_pathmnist, val_pathmnist_dataloader, loss_fn, optimizer_fn)
4 frames
/usr/local/lib/python3.10/dist-packages/torch/nn/functional.py in cross_entropy(input, target, weight, size_average, ignore_index, reduce, reduction, label_smoothing)
3477 if size_average is not None or reduce is not None:
3478 reduction = _Reduction.legacy_get_string(size_average, reduce)
-> 3479 return torch._C._nn.cross_entropy_loss(
3480 input,
3481 target,
RuntimeError: 0D or 1D target tensor expected, multi-target not supported
How can I fix this?