import h5py
import numpy as np
import torch
from torch.utils.data import Dataset, DataLoader
class H5Dataset(Dataset):
def __init__(self, h5_path):
self.h5_path = h5_path
self.h5_file = h5py.File(h5_path, 'r')
self.length = len(h5py.File(h5_path, 'r'))
def __getitem__(self, index):
record = self.h5_file[str(index)]
return (
record['data'].value,
record['target'].value,
)
def __len__(self):
return self.length
# --
# Make data
f = h5py.File('test.h5')
for i in range(256):
f['%s/data' % i] = np.random.uniform(0, 1, (1024, 1024))
f['%s/target' % i] = np.random.choice(1000)
# Runs correctly
dataloader = torch.utils.data.DataLoader(
H5Dataset('test.h5'),
batch_size=32,
num_workers=0,
shuffle=True
)
for i,(data,target) in enumerate(dataloader):
print(data.shape)
if i > 10:
break
# Throws error (sometimes, may have to restart python)
dataloader = torch.utils.data.DataLoader(
H5Dataset('test.h5'),
batch_size=32,
num_workers=8,
shuffle=True
)
for i,(data,target) in enumerate(dataloader):
print(data.shape)
if i > 10:
break
# KeyError: 'Traceback (most recent call last):
# File "/home/bjohnson/.anaconda/lib/python2.7/site-packages/torch/utils/data/dataloader.py", line 55, in _worker_loop
# samples = collate_fn([dataset[i] for i in batch_indices])
# File "<stdin>", line 11, in __getitem__
# File "h5py/_objects.pyx", line 54, in h5py._objects.with_phil.wrapper
# File "h5py/_objects.pyx", line 55, in h5py._objects.with_phil.wrapper
# File "/home/bjohnson/.anaconda/lib/python2.7/site-packages/h5py/_hl/group.py", line 167, in __getitem__
# oid = h5o.open(self.id, self._e(name), lapl=self._lapl)
# File "h5py/_objects.pyx", line 54, in h5py._objects.with_phil.wrapper
# File "h5py/_objects.pyx", line 55, in h5py._objects.with_phil.wrapper
# File "h5py/h5o.pyx", line 190, in h5py.h5o.open
# KeyError: Unable to open object (bad object header version number)
The formated nicely code can be find here: