Below is my custom dataloader that inherits from DatasetFolder (its exactly the same except for the def__getitem__
)
from torch.utils.data import Dataset
from PIL import Image
import os
import os.path
import sys
import torch
import numpy as np
def has_file_allowed_extension(filename, extensions):
"""Checks if a file is an allowed extension.
Args:
filename (string): path to a file
extensions (tuple of strings): extensions to consider (lowercase)
Returns:
bool: True if the filename ends with one of given extensions
"""
return filename.lower().endswith(extensions)
def is_image_file(filename):
"""Checks if a file is an allowed image extension.
Args:
filename (string): path to a file
Returns:
bool: True if the filename ends with a known image extension
"""
return has_file_allowed_extension(filename, IMG_EXTENSIONS)
def make_dataset(dir, class_to_idx, extensions=None, is_valid_file=None):
images = []
dir = os.path.expanduser(dir)
if not ((extensions is None) ^ (is_valid_file is None)):
raise ValueError("Both extensions and is_valid_file cannot be None or not None at the same time")
if extensions is not None:
def is_valid_file(x):
return has_file_allowed_extension(x, extensions)
for target in sorted(class_to_idx.keys()):
d = os.path.join(dir, target)
if not os.path.isdir(d):
continue
for root, _, fnames in sorted(os.walk(d)):
for fname in sorted(fnames):
path = os.path.join(root, fname)
if is_valid_file(path):
item = (path, class_to_idx[target])
images.append(item)
return images
class DatasetFolder(Dataset):
"""A generic data loader where the samples are arranged in this way: ::
root/class_x/xxx.ext
root/class_x/xxy.ext
root/class_x/xxz.ext
root/class_y/123.ext
root/class_y/nsdf3.ext
root/class_y/asd932_.ext
Args:
root (string): Root directory path.
loader (callable): A function to load a sample given its path.
extensions (tuple[string]): A list of allowed extensions.
both extensions and is_valid_file should not be passed.
transform (callable, optional): A function/transform that takes in
a sample and returns a transformed version.
E.g, ``transforms.RandomCrop`` for images.
target_transform (callable, optional): A function/transform that takes
in the target and transforms it.
is_valid_file (callable, optional): A function that takes path of an Image file
and check if the file is a valid_file (used to check of corrupt files)
both extensions and is_valid_file should not be passed.
Attributes:
classes (list): List of the class names.
class_to_idx (dict): Dict with items (class_name, class_index).
samples (list): List of (sample path, class_index) tuples
targets (list): The class_index value for each image in the dataset
"""
def __init__(self, root, loader, extensions=None, transform=None,
target_transform=None, is_valid_file=None):
super(DatasetFolder, self).__init__(root, transform=transform,
target_transform=target_transform)
# super().__init__(root, transform=transform,
# target_transform=target_transform)
#super(DatasetFolder, self).__init__(root, transform=transform)
classes, class_to_idx = self._find_classes(self.root)
samples = make_dataset(self.root, class_to_idx, extensions, is_valid_file)
if len(samples) == 0:
raise (RuntimeError("Found 0 files in subfolders of: " + self.root + "\n"
"Supported extensions are: " + ",".join(extensions)))
self.loader = loader
self.extensions = extensions
self.classes = classes
self.class_to_idx = class_to_idx
self.samples = samples
self.targets = [s[1] for s in samples]
def _find_classes(self, dir):
"""
Finds the class folders in a dataset.
Args:
dir (string): Root directory path.
Returns:
tuple: (classes, class_to_idx) where classes are relative to (dir), and class_to_idx is a dictionary.
Ensures:
No class is a subdirectory of another.
"""
if sys.version_info >= (3, 5):
# Faster and available in Python 3.5 and above
classes = [d.name for d in os.scandir(dir) if d.is_dir()]
else:
classes = [d for d in os.listdir(dir) if os.path.isdir(os.path.join(dir, d))]
classes.sort()
class_to_idx = {classes[i]: i for i in range(len(classes))}
return classes, class_to_idx
def __getitem__(self, index):
frames = []
framespath = []
"""
Args:
index (int): Index
Returns:
tuple: (sample, target) where target is class_index of the target class.
"""
path, target = self.samples[index]
frame = torch.from_numpy(np.load(path))
frames.append(frame)
framespath.append(path)
# sample = self.loader(path)
# if self.transform is not None:
# sample = self.transform(sample)
# if self.target_transform is not None:
# target = self.target_transform(target)
return frame, path
def __len__(self):
return len(self.samples)
IMG_EXTENSIONS = ('.jpg', '.jpeg', '.png', '.ppm', '.bmp', '.pgm', '.tif', '.tiff', '.webp')
def pil_loader(path):
# open path as file to avoid ResourceWarning (https://github.com/python-pillow/Pillow/issues/835)
with open(path, 'rb') as f:
img = Image.open(f)
return img.convert('RGB')
def accimage_loader(path):
import accimage
try:
return accimage.Image(path)
except IOError:
# Potentially a decoding problem, fall back to PIL.Image
return pil_loader(path)
def default_loader(path):
from torchvision import get_image_backend
if get_image_backend() == 'accimage':
return accimage_loader(path)
else:
return pil_loader(path)
class ImageFolder(DatasetFolder):
"""A generic data loader where the images are arranged in this way: ::
root/dog/xxx.png
root/dog/xxy.png
root/dog/xxz.png
root/cat/123.png
root/cat/nsdf3.png
root/cat/asd932_.png
Args:
root (string): Root directory path.
transform (callable, optional): A function/transform that takes in an PIL image
and returns a transformed version. E.g, ``transforms.RandomCrop``
target_transform (callable, optional): A function/transform that takes in the
target and transforms it.
loader (callable, optional): A function to load an image given its path.
is_valid_file (callable, optional): A function that takes path of an Image file
and check if the file is a valid_file (used to check of corrupt files)
Attributes:
classes (list): List of the class names.
class_to_idx (dict): Dict with items (class_name, class_index).
imgs (list): List of (image path, class_index) tuples
"""
def __init__(self, root, transform=None, target_transform=None,
loader=default_loader, is_valid_file=None):
super(ImageFolder, self).__init__(root, loader, IMG_EXTENSIONS if is_valid_file is None else None,
transform=transform,
target_transform=target_transform,
is_valid_file=is_valid_file)
self.imgs = self.samples
Then, this is my main code, its actually not for training, its a clustering task:
import torch.nn as nn
import torch
from dataset2 import DatasetFolder
from torch_kmeans import KMeans
from torchvision import transforms
import argparse
import matplotlib.cm
cmap = matplotlib.cm.get_cmap('Reds')
if torch.cuda.is_available():
device = torch.device("cuda")
print("working on gpu")
else:
device = torch.device("cpu")
print("working on cpu")
parser = argparse.ArgumentParser(description='K-means')
parser.add_argument('--batch-size', default=1,
type=int, help='mini-batch size')
parser.add_argument('--workers', default=0, type=int,
help='number of data loading workers')
args = parser.parse_args()
def validate(val_loader):
model = KMeans(n_clusters=3)
model = nn.DataParallel(model)
model.to(device)
# switch to evaluate mode
model.eval()
for i, (inputs, _, path) in enumerate(val_loader):
input_var = [input.cuda() for input in inputs]
# compute output
with torch.no_grad():
input_tensors = torch.stack(input_var)
input_tensors = torch.as_tensor(input_tensors)
# collapse in first dimension
torch.unsqueeze(input_tensors, 1)
input_tensors = input_tensors.permute(1, 0, 2)
output = model(input_tensors)
print(output)
if __name__ == '__main__':
valdir = r'C:\\\\'
transform = (transforms.Compose([
transforms.ToTensor()
]
),
transforms.Compose([
transforms.ToTensor()]
)
)
val_dataset = DatasetFolder(valdir, transform)
val_loader = torch.utils.data.DataLoader(
val_dataset,
batch_size=args.batch_size,
num_workers=args.workers, pin_memory=True)
print("--------------------------------------------------Validation--------------------------------------------------")
input_tensors, output_labels, centroids = validate(val_loader)
Thank you very much!