Based on my understanding of the question, I wrote the following piece of code to load the images and run predictions on them.
# Inference-only script: load every image found under ./data, push it through
# the already-constructed `model` on `device`, and print the predicted class
# index for each image. `model`, `device`, `transforms` and `datasets` are
# expected to be defined/imported earlier in the file.

# Preprocessing at prediction time must be deterministic. Random crops/flips
# are training-time augmentations: they make repeated predictions on the same
# image disagree and may crop the subject out of frame. Resize + CenterCrop is
# the standard evaluation counterpart of RandomResizedCrop(224).
data_transforms = {
    'predict': transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        # Channel-wise mean/std normalisation; these values must match the
        # ones the model was trained with.
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
    ])
}

# NOTE: ImageFolder expects the images to live in sub-directories of ./data
# (one sub-directory per class); the labels it derives from those folder names
# are not used below.
dataset = {'predict': datasets.ImageFolder("./data", data_transforms['predict'])}

# batch_size=1 and shuffle=False so predictions are printed one image at a
# time, in the dataset's own order.
dataloader = {
    'predict': torch.utils.data.DataLoader(
        dataset['predict'], batch_size=1, shuffle=False, num_workers=4
    )
}

outputs = list()  # predicted class index for each image, in loader order
since = time.time()  # start timestamp, available for elapsed-time reporting

# Switch dropout / batch-norm layers to inference behaviour.
model.eval()

# No gradients are needed for prediction; disabling autograd saves memory and
# time and lets the result be converted to NumPy directly.
with torch.no_grad():
    for inputs, labels in dataloader['predict']:
        inputs = inputs.to(device)
        output = model(inputs)
        # .numpy() only works on host memory, so bring the result back to the
        # CPU first (a no-op when `device` is already the CPU).
        index = output.cpu().numpy().argmax()
        outputs.append(index)
        print(index)
I hope this helps and solves your problem.