Change:
labels = labels.to(device)
to
labels = labels.to(device=device, dtype=torch.int64)
CrossEntropyLoss expects its targets to be class indices, and hence a Long Tensor (torch.int64), but you gave it a Float Tensor.