I am trying to use torchvision.models.vit_b_32(). However, when I pass in some arbitrary data I get all zeros in the output
import torch
import torchvision
import torchvision.transforms as transforms
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
transform = transforms.Compose(
[transforms.ToTensor(),
transforms.Resize(224),
transforms.Normalize(mean=[0.485, 0.456, 0.406],
std=[0.229, 0.224, 0.225])])
batch_size = 4
testset = torchvision.datasets.CIFAR10(root='./data', train=False,
download=True, transform=transform)
testloader = torch.utils.data.DataLoader(testset, batch_size=batch_size,
shuffle=False, num_workers=2)
newmodel = torchvision.models.vit_b_32().to(device)
images, labels = next(iter(testloader))
images = images.to(device)
labels = labels.to(device)
print(torch.sum(newmodel(images)[0]))
Why is the last print statement printing zero?