Problem testing image classification predictions: torchvision.datasets.ImageFolder vs. manual loading with torchvision.transforms.Compose on ResNet50 and my custom architecture

When I train my custom model and fine-tune the pretrained ResNet50 using torchvision.datasets.ImageFolder, I see some underfitting/overfitting during training but still get a good 98% accuracy on validation. If I use torchvision.datasets.ImageFolder to load the test folder, I get very good predictions and a good confusion matrix. But if I reuse the same transform on images that I load myself while scanning the test folder, I always get poor predictions. I don't know what the problem is in my case.

This is the confusion matrix with torchvision.datasets.ImageFolder:


model = models.resnet50(pretrained=False)
model.fc = nn.Sequential(nn.Linear(2048, 512),
                         nn.Linear(512, 20))
model = model.to(device)
checkpoint = torch.load(os.path.join(w_dir, 'checkpoints', test_name, 'last.pth'), map_location=device)
model.load_state_dict(checkpoint['model_state_dict'])
test_transform = transforms.Compose([
    transforms.Resize((img_size, img_size)),
    transforms.ToTensor(),
    transforms.Normalize(
        mean=[0, 0, 0],
        std=[1, 1, 1])])
dataset_name = 'datasets3*'
test_path = os.path.join(w_dir, 'dataset_cropped_20cl',
                         dataset_name, 'test')
test_dataset = datasets.ImageFolder(
    root=test_path,
    transform=test_transform)
test_loader = DataLoader(
    test_dataset, batch_size=BATCH_SIZE, shuffle=False,
    num_workers=4, pin_memory=True)
model.eval()
print('Make prediction')
pred_list = list()
true_list = list()
true_name_list = list()
pred_name_list = list()
save_pred_path = os.path.join(w_dir, 'CNN_prediction', 'checkpoints')
if os.path.exists(save_pred_path):
    shutil.rmtree(save_pred_path)
os.makedirs(save_pred_path)
with torch.no_grad():
    for i, data in tqdm(enumerate(test_loader), total=len(test_loader)):
        image, labels = data
        image = image.to(device)
        labels = labels.to(device)
        outputs = model(image)
        #output_label = torch.topk(outputs, 1)
        #pred_class = labels[int(output_label.indices.item())]
        #pred_list.append(int(output_label.indices))
        #true_list.append(int(labels))
        _, preds = torch.max(outputs.data, 1)
        pred_list.append(int(preds))
        true_list.append(int(labels))
        true_name_list.append(id_to_labels[int(labels)])
        pred_name_list.append(id_to_labels[int(preds)])
        if int(preds) == int(labels):
            print('prediction is good')
            print(f' true class is {int(labels)}, predicted class is {int(preds)}')
            print(f"GT: {id_to_labels[int(labels)]}, pred: {id_to_labels[int(preds)]}")
        else:
            print('prediction is not good')
            print(f' true class is {int(labels)}, predicted class is {int(preds)}')
            print(f"GT: {id_to_labels[int(labels)]}, pred: {id_to_labels[int(preds)]}")
This is the confusion matrix using my own transform:

test_transform = transforms.Compose([
    transforms.ToPILImage(),
    transforms.Resize((img_size, img_size)),
    transforms.ToTensor(),
    transforms.Normalize(
        mean=[0, 0, 0],
        std=[1, 1, 1])])
dataset_name = 'datasets3*'
dataset_path = os.path.join(w_dir, 'dataset_cropped_20cl', dataset_name, 'test')
pred_list = list()
true_list = list()
true_name_list = list()
pred_name_list = list()
model.eval()
with torch.no_grad():
    #scan test folder
    for folder_name in os.listdir(dataset_path):
        gt_class = folder_name
        for roots, dirs, files in os.walk(os.path.join(dataset_path, folder_name)):
            for file in files:
                img_path = os.path.join(roots, file)
                image = cv2.imread(img_path)
                orig_image = image.copy()
                image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
                image = test_transform(image)
                image = torch.unsqueeze(image, 0)
                with torch.no_grad():
                    outputs = model(image.to(device))
                _, pred_class = torch.max(outputs.data, 1)
                pred_list.append(int(pred_class))
                true_list.append(int(gt_class))
                true_name_list.append(id_to_labels[int(gt_class)])
                pred_name_list.append(id_to_labels[int(pred_class)])

Thanks for your help.

Your code is unfortunately not properly formatted, so I'm unsure which transformation you are using in which run, but note that:

transforms.ToTensor(),
transforms.Normalize(
    mean=[0, 0, 0],
    std=[1, 1, 1])])

will not normalize the tensors, since zeros are subtracted from them and the result is then divided by ones. If you used a different Normalize transformation during training, you should reuse it during testing as well.
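Normalize computes (x - mean) / std per channel, so mean=0 and std=1 leave the [0, 1] output of ToTensor() unchanged, while mean=std=0.5 maps it to [-1, 1]. A minimal sketch of that arithmetic with plain tensor ops, assuming a fake 3-channel image; `normalize` is a hypothetical helper that only mirrors what transforms.Normalize does:

```python
import torch

def normalize(img: torch.Tensor, mean, std) -> torch.Tensor:
    # Hypothetical helper mirroring transforms.Normalize: (x - mean) / std per channel.
    mean_t = torch.tensor(mean, dtype=img.dtype).view(-1, 1, 1)
    std_t = torch.tensor(std, dtype=img.dtype).view(-1, 1, 1)
    return (img - mean_t) / std_t

# A fake 3x2x2 image with values in [0, 1], as ToTensor() would produce.
img = torch.tensor([0.0, 0.25, 0.5, 1.0]).view(1, 2, 2).repeat(3, 1, 1)

identity = normalize(img, [0, 0, 0], [1, 1, 1])            # values unchanged
scaled = normalize(img, [0.5, 0.5, 0.5], [0.5, 0.5, 0.5])  # values mapped to [-1, 1]
print(identity.min().item(), identity.max().item())  # 0.0 1.0
print(scaled.min().item(), scaled.max().item())      # -1.0 1.0
```

A model trained on inputs in [-1, 1] but tested on inputs in [0, 1] sees a shifted input distribution, which is why the training-time Normalize has to be reused at test time.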

  • Yes, I also used the same transformation for testing as during training:
train_transform = transforms.Compose([
                                transforms.Resize((img_size, img_size)),
                                transforms.RandomHorizontalFlip(p=0.5),
                                transforms.RandomVerticalFlip(p=0.5),
                                transforms.RandomAffine(180,scale=(0.25,2)),
                                transforms.RandomPerspective(p=0.5),
                                transforms.ColorJitter(brightness=0.5,hue=0.2),
                                transforms.ToTensor(),
                                transforms.Normalize(
                                                    mean=[0.5, 0.5, 0.5],
                                                    std=[0.5, 0.5, 0.5]),
                                ])

valid_transform = transforms.Compose([
                                    transforms.Resize((img_size, img_size)),
                                    transforms.ToTensor(),
                                    transforms.Normalize(
                                                        mean=[0.5, 0.5, 0.5],
                                                        std=[0.5, 0.5, 0.5]),
                                    ])
train_path=os.path.join(w_dir,'dataset_cropped_20cl',
                        dataset_name,'train')
val_path=os.path.join(w_dir,'dataset_cropped_20cl',
                      dataset_name,'val')


train_dataset = datasets.ImageFolder(root=train_path,
                                    transform=train_transform)
valid_dataset = datasets.ImageFolder(root=val_path,
                                     transform=valid_transform)
train_loader = DataLoader(train_dataset,batch_size=BATCH_SIZE,
                          shuffle=True,num_workers=4, pin_memory=True)
valid_loader = DataLoader(valid_dataset, batch_size=BATCH_SIZE, 
                          shuffle=False,num_workers=4, pin_memory=True)
  • Reload the trained model
 #recall model
if model_name=='custom':
    model=CNN_network(img_size=img_size,
                      in_channels=3,
                      conv_depths=[8,16,32,64,128,256],
                      lin_depths=[1024,512],
                      dropout_conv=False,
                      dropout_li=False,
                      nb_classes=nbClasses,
                      classifer_activation=None,
                      activation='relu').to(device)
elif model_name=='resnet50':
    model = models.resnet50(pretrained = False)
    model.fc = nn.Sequential(nn.Linear(2048, 512),
                             nn.Linear(512, 20))
    model=model.to(device)
  
#reload model
checkpoint = torch.load(os.path.join(w_dir,'checkpoints',
                                     test_name,'last.pth'), 
                                     map_location=device)
model.load_state_dict(checkpoint['model_state_dict'])
model.eval()
  • Using ImageFolder to test the model, I got the good confusion matrix:
test_transform = transforms.Compose([
                                    transforms.Resize((img_size, img_size)),
                                    transforms.ToTensor(),
                                    transforms.Normalize(
                                        mean=[0.5, 0.5, 0.5],
                                        std=[0.5, 0.5, 0.5])])
dataset_name='datasets3*'
test_path=os.path.join(w_dir,'dataset_cropped_20cl',
                       dataset_name,'test')
test_dataset = datasets.ImageFolder(root=test_path,
                                    transform=test_transform)
test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE,
                         shuffle=False,num_workers=4, pin_memory=True)
model.eval()
with torch.no_grad():
    for i, data in tqdm(enumerate(test_loader), total=len(test_loader)):
        image, labels = data
        image = image.to(device)
        labels = labels.to(device)

        #get prediction
        outputs = model(image)
        _, preds = torch.max(outputs.data, 1)
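The first evaluation loop in this thread calls int(preds) and int(labels), which only works when a batch holds a single image. A minimal sketch, assuming batched logits of shape [B, C] and integer labels of shape [B], of collecting predictions with .tolist() and building a confusion matrix by hand; `collect` and `confusion_matrix` are hypothetical helpers, not part of torchvision:

```python
import torch

def collect(batches):
    # batches: iterable of (logits, labels) pairs, logits [B, C], labels [B].
    preds, trues = [], []
    for logits, labels in batches:
        preds.extend(logits.argmax(dim=1).tolist())  # works for any batch size B
        trues.extend(labels.tolist())
    return preds, trues

def confusion_matrix(trues, preds, num_classes):
    # Rows are ground-truth classes, columns are predicted classes.
    cm = torch.zeros(num_classes, num_classes, dtype=torch.long)
    for t, p in zip(trues, preds):
        cm[t, p] += 1
    return cm

# Two fake batches of 3-class logits with different batch sizes.
batches = [
    (torch.tensor([[2.0, 0.1, 0.3], [0.2, 1.5, 0.0]]), torch.tensor([0, 1])),
    (torch.tensor([[0.1, 0.2, 3.0]]), torch.tensor([1])),
]
preds, trues = collect(batches)
print(preds, trues)  # [0, 1, 2] [0, 1, 1]
print(confusion_matrix(trues, preds, num_classes=3))
```

In a real loop, each (logits, labels) pair would be (model(image), labels) moved to the CPU.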

  • Without ImageFolder, I scan the test folder myself. In this case, I use cv2 to load the images and then apply the transformation. I got the poor confusion matrix:
test_transform = transforms.Compose([
                                    transforms.ToPILImage(),
                                    transforms.Resize((img_size, img_size)),
                                    transforms.ToTensor(),
                                    transforms.Normalize(
                                        mean=[0.5, 0.5, 0.5],
                                        std=[0.5, 0.5, 0.5]) ])
dataset_name='datasets3*'
dataset_path=os.path.join(w_dir,'dataset_cropped_20cl',
                          dataset_name,'test')
model.eval()
with torch.no_grad():
    for folder_name in os.listdir(dataset_path):
        gt_class=folder_name
        for roots,dirs,files in os.walk(
                           os.path.join(dataset_path,folder_name)):
            for file in files:
                #load and convert image to true format
                img_path=os.path.join(roots,file)
                image = cv2.imread(img_path)
                image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
                image=test_transform(image)
                image = torch.unsqueeze(image, 0).float()

                #get predictions
                outputs = model(image.to(device))
                _, pred_class = torch.max(outputs.data, 1)
  • I also tried PIL.Image to load the images, but I got the same result:
test_transform = transforms.Compose([
                                    transforms.Resize((img_size, img_size)),
                                    transforms.ToTensor(),
                                    transforms.Normalize(
                                        mean=[0.5, 0.5, 0.5],
                                        std=[0.5, 0.5, 0.5]) ])
model.eval()
with torch.no_grad():
    for folder_name in os.listdir(dataset_path):
        gt_class=folder_name
        for roots,dirs,files in os.walk(
                           os.path.join(dataset_path,folder_name)):
            for file in files:
                 img_path=os.path.join(roots,file)
                 image=Image.open(img_path)
                 image=test_transform(image)
                 image = torch.unsqueeze(image, 0).float()
                 #get predictions
                 outputs = model(image.to(device))
                 _, pred_class = torch.max(outputs.data, 1)
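One further difference between this manual PIL loop and ImageFolder: torchvision's default loader opens each file and calls .convert("RGB"), so grayscale, palette, or RGBA files always become 3-channel images. A minimal sketch of a matching loader, assuming Pillow is installed; `load_rgb` is a hypothetical name:

```python
import os
import tempfile

from PIL import Image

def load_rgb(path: str) -> Image.Image:
    # Mirror ImageFolder's default PIL loader: open the file and force 3-channel RGB.
    with open(path, "rb") as f:
        img = Image.open(f)
        return img.convert("RGB")

# Quick check with a temporary single-channel (grayscale) PNG.
with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "gray.png")
    Image.new("L", (4, 4), color=128).save(path)
    img = load_rgb(path)
    print(img.mode, img.size)  # RGB (4, 4)
```

Without the conversion, such files would produce tensors with the wrong number of channels for the 3-value Normalize.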

How are you defining the targets in your manual approach and did you compare the class indices to what ImageFolder returns?

  • I made a mistake when structuring the dataset folder. I have 20 classes and named each class folder with a number. I assumed that ImageFolder would use that number as the target, but when I checked dataset.class_to_idx, the folder-name-to-index mapping was not what I expected, because ImageFolder sorts the folder names as strings ('10' sorts before '2'):
train_dataset.class_to_idx={'0': 0, '1': 1, '10': 2, '11': 3, '12': 4, '13': 5, '14': 6, '15': 7, '16': 8, '17': 9, '18': 10, '19': 11, '2': 12, '3': 13, '4': 14, '5': 15, '6': 16, '7': 17, '8': 18, '9': 19}

I think I don't need to name the class folders with numbers at all! I resolved my problem by correcting the mapping from predicted indices to class names when making the predictions and the confusion matrix.
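ImageFolder assigns indices by sorting the class folder names as strings, which is why '2' ends up at index 12. A minimal sketch in plain Python that reproduces that mapping and uses it for the targets of a manual scan; `class_to_idx` and `idx_to_class` are built by hand here and only mimic ImageFolder's attribute of the same name:

```python
# Folder names as they appear on disk (numeric strings, 20 classes).
folders = [str(i) for i in range(20)]

# ImageFolder sorts class names as strings, not as numbers.
classes = sorted(folders)
class_to_idx = {name: idx for idx, name in enumerate(classes)}
idx_to_class = {idx: name for name, idx in class_to_idx.items()}

print(class_to_idx["2"])  # 12, not 2
print(idx_to_class[2])    # 10

# In a manual scan, derive the target from the folder name via the same mapping
# instead of int(folder_name), so it agrees with the model's output indices.
folder_name = "2"
target = class_to_idx[folder_name]
pred_idx = 12  # e.g. the argmax of the model output
print(target == pred_idx, idx_to_class[pred_idx])  # True 2
```

With a real dataset, test_dataset.class_to_idx gives the same mapping directly. Zero-padding the folder names (00, 01, 02 up to 19) would also make the string order match the numeric order.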

But I have other questions:

  • Why don't we need to convert the targets to one-hot vectors like [0,0,0,1,0,…,0] instead of torch.tensor([3]) when using CrossEntropyLoss?

  • When I added a softmax after the last linear layer of my custom CNN, or at the end of ResNet50, the model reached lower accuracy during training than without torch.nn.Softmax, in both cases. Why?

To get the desired logit from the model output, you could either multiply the logits with a one-hot encoded target tensor or directly index the logits with the class index; both give the same result.
Storing the target as a single integer instead of a large one-hot encoded tensor containing mostly zeros thus sounds reasonable.
However, note that in newer versions nn.CrossEntropyLoss also accepts "soft" targets (and thus also one-hot encoded targets) if you pass the target as a floating point tensor.
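Both points can be checked with a small sketch: indexing the log-probabilities with the class index picks the same values as multiplying by a one-hot vector and summing, and nn.CrossEntropyLoss returns the same loss for integer class targets and float one-hot targets. This assumes a PyTorch version that supports probability targets; the logits are random placeholders:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
logits = torch.randn(4, 5)            # batch of 4 samples, 5 classes
targets = torch.tensor([3, 0, 4, 1])  # class indices (int64)
one_hot = F.one_hot(targets, num_classes=5).float()

log_probs = F.log_softmax(logits, dim=1)
picked_by_index = log_probs[torch.arange(4), targets]  # direct indexing
picked_by_one_hot = (log_probs * one_hot).sum(dim=1)   # one-hot multiply + sum
print(torch.allclose(picked_by_index, picked_by_one_hot))  # True

criterion = nn.CrossEntropyLoss()
loss_index = criterion(logits, targets)    # integer class targets
loss_one_hot = criterion(logits, one_hot)  # float "soft" targets (one-hot here)
print(torch.allclose(loss_index, loss_one_hot))  # True
```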

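That's expected, as nn.CrossEntropyLoss expects raw logits as the model output, since F.log_softmax is applied internally. Adding nn.Softmax before the loss applies softmax twice: the inputs to log_softmax are squashed into [0, 1], which bounds how low the loss can go and typically hurts training.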
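A minimal sketch of the double-softmax effect, assuming 20 classes as in this thread and a hand-made, very confident logit vector: with raw logits the loss approaches 0, but after an extra softmax the loss cannot drop below log((e + 19) / e), because the largest possible input is 1 and the others are at least 0:

```python
import math

import torch
import torch.nn as nn

num_classes = 20
criterion = nn.CrossEntropyLoss()
target = torch.tensor([0])

# A very confident, correct prediction: class 0 has by far the largest logit.
logits = torch.zeros(1, num_classes)
logits[0, 0] = 20.0

loss_raw = criterion(logits, target)     # raw logits: loss is close to 0
probs = torch.softmax(logits, dim=1)     # what an extra nn.Softmax layer would output
loss_double = criterion(probs, target)   # log_softmax is applied again inside the loss

# Lower bound on the loss once the inputs are squashed into [0, 1].
floor = math.log((math.e + num_classes - 1) / math.e)
print(f"raw: {loss_raw.item():.6f}  double softmax: {loss_double.item():.4f}  floor: {floor:.4f}")
```

Even a perfect prediction keeps a loss of about 2 in the double-softmax case, so the gradients never push the model toward fully confident outputs.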