I have a lot of images with .gif and .octet-stream extensions. Since ImageFolder will ignore those files, I use DatasetFolder and provide my own img_extensions and loader, as suggested by other folks on this forum. I then create a dataloader and try to iterate through it. Unfortunately, it gets stuck somewhere forever.
The folder structure is the following. I have attached all images except ‘01.octet-stream’ as this forum does not support it.
# debug/0/01.octet-stream
# debug/0/02.jpeg
# debug/0/03.gif
# debug/1/01.octet-stream
# debug/1/02.jpeg
# debug/1/03.gif
Here is my code; you can run it directly in a notebook.
from __future__ import print_function, division
import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim import lr_scheduler
import numpy as np
import torchvision
from torchvision import datasets, models, transforms
from sklearn.utils.class_weight import compute_class_weight
import matplotlib.pyplot as plt
import torch.nn.functional as F
import time
import os
import copy
from torchvision.datasets import ImageFolder
plt.ion()
data_transforms = transforms.Compose([transforms.Resize(300),
                                      transforms.CenterCrop(299),
                                      transforms.RandomHorizontalFlip(),
                                      transforms.ToTensor(),
                                      transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])
img_extensions = ['.jpg', '.jpeg', '.png', '.ppm', '.bmp', '.pgm', '.tif', '.gif', '.octet-stream']
def my_loader(path):
    from torchvision import get_image_backend
    # accimage_loader is the loader torchvision uses for the accimage backend
    from torchvision.datasets.folder import accimage_loader
    from PIL import Image

    def my_pil_loader(path):
        print("loading {}".format(path))
        with open(path, 'rb') as f:
            img = Image.open(f)
            # convert while the file is still open, as torchvision's pil_loader does
            return img.convert('RGB')

    if get_image_backend() == 'accimage':
        print('{} uses accimage'.format(path))
        try:
            return accimage_loader(path)
        except IOError:
            print('{} accimage loading fail, using PIL'.format(path))
            return my_pil_loader(path)
    else:
        print('{} uses PIL'.format(path))
        return my_pil_loader(path)
my_loader('./debug/0/03.gif')
data_dir = './debug/'
batch_size = 32
image_datasets = datasets.DatasetFolder(data_dir, my_loader, img_extensions,
                                        data_transforms)
dataloaders = torch.utils.data.DataLoader(image_datasets, batch_size=batch_size,
                                          shuffle=True, num_workers=4)
dataset_sizes = len(image_datasets)
print(dataset_sizes)
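For completeness, DatasetFolder also exposes classes and samples, so printing them is a quick way to confirm that all six files were picked up (the expected values in the comments follow from the folder structure above):
# quick sanity check on what DatasetFolder discovered
print(image_datasets.classes)   # expected: ['0', '1']
print(image_datasets.samples)   # expected: 6 (path, class_index) pairs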
Everything works fine so far. However, when I try to iterate through the dataloader by running the following code, the program gets stuck forever! It seems the code runs into a dead loop somewhere even before loading any images, as I do not see any of the loader's print output. What's wrong with my implementation?
If I replace the DatasetFolder with ImageFolder and get rid of the customized loader and extensions, everything works fine. Very weird…
index = 0
for inputs, labels in dataloaders:
    print(index)
    print('inputs')
    print(inputs.size())
    print('labels')
    print(labels.size())
    index += 1
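For reference, the ImageFolder variant that does work looks roughly like this — same transforms, just the default loader and extensions, so the .gif and .octet-stream files are skipped:
# working variant for comparison (only picks up the .jpeg files)
image_datasets = datasets.ImageFolder(data_dir, data_transforms)
dataloaders = torch.utils.data.DataLoader(image_datasets, batch_size=batch_size,
                                          shuffle=True, num_workers=4)
for inputs, labels in dataloaders:
    print(inputs.size(), labels.size())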
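For what it's worth, a minimal way to narrow down where it hangs would be to pull one sample straight from the dataset (no worker processes involved) and then rerun the loop with num_workers=0:
# fetch a single sample directly from the dataset, bypassing the DataLoader
sample, label = image_datasets[0]
print(sample.size(), label)

# same loop, but with single-process loading
debug_loader = torch.utils.data.DataLoader(image_datasets, batch_size=batch_size,
                                           shuffle=True, num_workers=0)
for inputs, labels in debug_loader:
    print(inputs.size(), labels.size())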