I put together a version that caches my preprocessed images. This implementation saves some memory and gives a speedup, since it avoids resizing the images in each worker process. About the torch.multiprocessing.set_sharing_strategy('file_system')
call I am not absolutely sure — maybe @ptrblck has some comments on this — but it was suggested by the PyTorch framework, and without it I got an exception after 14+ hours of training when passing the same dataloader to 30+ different model training loops.
import os

import numpy as np
import torch

# Must run at the beginning of the program, before any DataLoader workers
# are spawned.  The default 'file_descriptor' strategy can exhaust the fd
# limit when many tensors are shared across processes, which surfaces as
# an exception deep into long training runs.
torch.multiprocessing.set_sharing_strategy('file_system')

import torchvision.transforms as transforms
from multiprocessing import Manager
from PIL import Image
from torch.utils.data import Dataset, DataLoader
class DatasetCache(object):
    """Process-safe cache for preprocessed (image, label) pairs.

    Backed by a ``multiprocessing.Manager`` dict so that all DataLoader
    worker processes share one cache instead of each one re-preprocessing
    the same images.
    """

    def __init__(self, manager, use_cache=True):
        """
        Args:
            manager: object exposing ``dict()`` — normally a
                ``multiprocessing.Manager`` — that provides the shared dict.
            use_cache: (bool) when False every operation is a no-op and
                ``get`` raises.
        """
        self.use_cache = use_cache
        self.manager = manager
        self._dict = manager.dict()

    def is_cached(self, key):
        """Return True if `key` has a cached entry (always False when disabled)."""
        if not self.use_cache:
            return False
        return str(key) in self._dict

    def reset(self):
        """Drop every cached entry."""
        self._dict.clear()

    def get(self, key):
        """Return the cached ``(img, lbl)`` tuple for `key`.

        Raises:
            AttributeError: if caching is disabled.
            KeyError: if `key` was never cached.
        """
        if not self.use_cache:
            raise AttributeError('Data caching is disabled and get function is unavailable! Check your config.')
        return self._dict[str(key)]

    def cache(self, key, img, lbl):
        """Store ``(img, lbl)`` under `key`; no-op when disabled or already present."""
        # only store if full data in memory is enabled
        if not self.use_cache:
            return
        # first writer wins: avoids re-serializing identical data through the manager
        if str(key) in self._dict:
            return
        self._dict[str(key)] = (img, lbl)
class CachedListDataset(Dataset):
    '''Load image/labels from a list file, caching preprocessed results.

    The list file is like:

      a.jpg label ...
    '''
    def __init__(self, root, list_file, cache, preprocess=None, transform=None):
        '''
        Args:
          root: (str) directory to images.
          list_file: (str/[str]) path to index file(s).
          cache: (DatasetCache) shared object cache.
          preprocess: (function) preprocessing function applied once per
            image before caching (e.g. deterministic resizing).
          transform: (function) image transforms applied on every access,
            after the cache (e.g. random augmentation).
        '''
        self.root = root
        self.cache = cache
        self.preprocess = preprocess
        self.transform = transform
        self.fnames = []
        self.labels = []

        # Accept a single list file or several (especially useful for the
        # voc07/voc12 combination).  Read them directly in Python instead of
        # shelling out to `cat` through a shared /tmp file: safer (no shell
        # interpolation), portable, and free of tmp-file races.
        if not isinstance(list_file, list):
            list_file = [list_file]
        lines = []
        for path in list_file:
            with open(path) as f:
                lines.extend(f.readlines())

        self.num_imgs = len(lines)
        for line in lines:
            splited = line.strip().split()
            self.fnames.append(splited[0])
            self.labels.append(np.array([int(splited[1])]))

    def reset_memory(self):
        '''Clear the shared cache (e.g. between training phases).'''
        self.cache.reset()

    def __len__(self):
        return self.num_imgs

    def __getitem__(self, idx):
        '''Load image.

        Args:
          idx: (int) image index.

        Returns:
          img: (tensor) image tensor.
          labels: (tensor) class label targets.
        '''
        if self.cache.is_cached(idx):
            img, labels = self.cache.get(idx)
        else:
            # Load image and boxes.
            fname = self.fnames[idx]
            with Image.open(os.path.join(self.root, fname)) as i:
                if i.mode != 'RGB':
                    i = i.convert('RGB')
                else:
                    # Force pixel data into memory: PIL opens lazily, and the
                    # file is closed when the `with` block exits.
                    i.load()
                img = i
            if self.preprocess:
                img = self.preprocess(img)
            labels = torch.from_numpy(self.labels[idx])
            self.cache.cache(idx, img, labels)
        if self.transform:
            img = self.transform(img)
        return img, labels
# usage — build the shared cache, the dataset, and the loader.  Guarded by
# __main__ so that spawn-based multiprocessing (Windows/macOS, and the
# DataLoader workers themselves) can re-import this module without
# re-creating the Manager process.
if __name__ == '__main__':
    manager = Manager()
    cache = DatasetCache(manager)
    ds = CachedListDataset(root='/path-to-data',
                           list_file='/path-to-list-file',
                           cache=cache,
                           # deterministic, done once per image; result is cached
                           preprocess=transforms.Compose([
                               transforms.Resize(256)
                           ]),
                           # random augmentations must differ on every access,
                           # so they run after the cache
                           transform=transforms.Compose([
                               transforms.RandomResizedCrop((224, 224), scale=[0.7, 1.0]),
                               transforms.RandomHorizontalFlip(),
                               transforms.ColorJitter(0.25, 0.25, 0.25),
                               transforms.RandomRotation(2),
                               transforms.ToTensor(),
                               transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])  # in case of imagenet statistics
                           ]))
    # NOTE: the class is torch.utils.data.DataLoader (capital L);
    # fill in batch_size, num_workers, etc. in place of `...`.
    dl = DataLoader(ds, ...)