import h5py
import numpy as np
import torch
from torch.utils.data import Dataset, DataLoader
class H5Dataset(Dataset):
    """PyTorch Dataset backed by an HDF5 file.

    The file is expected to hold one group per sample, named by the
    stringified integer index, each containing 'data' and 'target'
    datasets (see the generation script below).

    Only the path is stored on the instance — no open h5py.File handle.
    An open handle is not safely shareable across the worker processes
    a DataLoader spawns, so the file is re-opened on every access.
    Re-opening per item has a real cost; if it becomes a bottleneck,
    open the file lazily once per worker instead.
    """

    def __init__(self, h5_path):
        # Keep the dataset picklable for DataLoader workers: path only.
        self.h5_path = h5_path

    def __getitem__(self, index):
        # NOTE: `Dataset.value` was removed in h5py 3.0; `ds[()]` is the
        # supported way to read a whole dataset into memory.
        with h5py.File(self.h5_path, 'r') as record:
            group = record[str(index)]
            data = group['data'][()]
            target = group['target'][()]
        return (data, target)

    def __len__(self):
        # Number of top-level groups == number of samples.
        with h5py.File(self.h5_path, 'r') as record:
            return len(record)
# --
# Make data: one group per stringified index, with a 'data' array and a
# scalar 'target' label. Use an explicit 'w' mode (the old no-mode
# default is deprecated/ambiguous) and a context manager so the handle
# is closed before any DataLoader worker tries to open the file.
with h5py.File('test.h5', 'w') as f:
    for i in range(256):
        f['%s/data' % i] = np.random.uniform(0, 1, (1024, 1024))
        f['%s/target' % i] = np.random.choice(10)
# Runs correctly
# dataloader = torch.utils.data.DataLoader(
# H5Dataset('test.h5'),
# batch_size=32,
# num_workers=0,
# shuffle=False
# )
#
# count1=0
# for i, (data, target) in enumerate(dataloader):
# # print(data.shape)
# count1+=target
# print('count1 is equal to \n{}:'.format(count1))
# print(torch.sum(count1))
# if i > 10:
# break
# Multi-worker run. Each worker process must open its own h5py handle;
# the Dataset above re-opens the file inside __getitem__, so sharing a
# stale handle across workers (the original source of intermittent
# errors) is avoided. The __main__ guard is required by multiprocessing
# start methods that spawn rather than fork (Windows/macOS default).
if __name__ == '__main__':
    dataloader = torch.utils.data.DataLoader(
        H5Dataset('test.h5'),
        batch_size=32,
        num_workers=24,
        shuffle=False,
    )
    # Accumulate the targets batch-wise; count2 ends up a tensor of
    # per-position sums across batches.
    count2 = 0
    for i, (data, target) in enumerate(dataloader):
        count2 += target
    print('count2 is equal to :\n{}'.format(count2))
    print(torch.sum(count2))
If we modify it like this, it works well. However, I think that in `__getitem__`, an `h5py.File` open is performed for every item, which should be an expensive operation. I hope the accepted answer can help you debug further. Thank you in advance.