Hi! I am now trying to measure some baseline numbers of models on ImageNet ILSVRC2012, but weirdly I cannot use pretrained models to reproduce high accuracies even on the train set. It seems my preprocessing is correct. Can you please point out what goes wrong my codes? Thank you very much!
import numpy as np
import torch
import torchvision
from tqdm import tqdm
from torchvision import models
from torchvision import transforms
baseline_imagenet_dataset = torchvision.datasets.ImageNet(root = "../datasets/ImageNet/train/",
split="train",
transform = transforms.Compose([
transforms.Resize(256),
transforms.CenterCrop(224),
transforms.ToTensor(),
transforms.Normalize(mean=[0.485, 0.456, 0.406],std=[0.229, 0.224, 0.225])
])
)
baseline_imagenet_loader = torch.utils.data.DataLoader(baseline_imagenet_dataset,
shuffle= False,
batch_size = 2048,
num_workers = 32,
prefetch_factor=8)
def evaluate_dataparallel_model(model, dataloader):
predictions = []
match_list = []
with torch.no_grad():
for img_batch, labels in tqdm(dataloader):
img_batch_gpu = img_batch.cuda()
labels_gpu = labels.cuda()
preds = model(img_batch_gpu) # output (batch_size, 1000)
digit_preds = torch.argmax(preds, dim = 1)
matches = labels_gpu == digit_preds
predictions.append(digit_preds)
match_list.append(matches)
predictions = torch.cat(predictions)
matches = torch.cat(match_list)
precision = matches.sum() / matches.shape[0]
return predictions.cpu(), matches.cpu(), precision.cpu()
resnet50 = models.resnet50(pretrained=True)
resnet50 = torch.nn.DataParallel(resnet50).cuda()
pred, m, precis = evaluate_dataparallel_model(resnet50, baseline_imagenet_loader)
print(precis)
the output during running is
72%|███████▏ | 452/626 [16:41<08:12, 2.83s/it] /home/liangf/miniconda3/envs/torch/lib/python3.8/site-packages/PIL/TiffImagePlugin.py:793: UserWarning: Corrupt EXIF data. Expecting to read 4 bytes but only got 0.
warnings.warn(str(msg))
100%|██████████| 626/626 [23:09<00:00, 2.22s/it]
And the output precision is only
tensor(0.0396)
Does this issue relate to Corrupt EXIF data
?