Define torch dataloader with h5py dataset

I wrote the following code and it acts really weird. sometimes the next(iter(dataloader)) works well and sometimes it throws an error. I shared the error below. Any thoughts?

def proc_images(data_dir ='flower-data', train = True):
"""
Saves compressed, resized images as HDF5 datsets
Returns
    data.h5, where each dataset is an image or class label
    e.g. X23,y23 = image and corresponding class label
"""

image_path_list = sorted([os.path.join(data_dir+ '/jpg', filename) for filename in os.listdir(data_dir + '/jpg') if not filename.startswith('.')])
label_dict = loadmat(data_dir + '/config/imagelabels.mat')
label_list = label_dict['labels'][0]
label_list[:] = [i - 1 for i in label_list]
ids = loadmat(data_dir + '/config/setid.mat')
if train:
    ids = np.concatenate((ids['trnid'], ids['valid']), 1) - 1
    name_data = 'flowers_train_data.h5'
    name_labels = 'flowers_train_labels.h5'
else:
    ids = ids['tstid'] - 1
    name_data = 'flowers_test_data.h5'
    name_labels = 'flowers_test_labels.h5'
    
label_list = [ label_list[index] for index in ids.astype(int).tolist()[0] ]
image_path_list = [ image_path_list[index] for index in ids.astype(int).tolist()[0] ]

with h5py.File(name_data, 'w') as hf: 
    for i,img in enumerate(image_path_list):            
        # Images
        image = Image.open(img)
        (HEIGHT, WIDTH) = image.size
        Xset = hf.create_dataset(
            name='X'+str(i),
            data=image)
        print("\r", i, end="")
with h5py.File(name_labels, 'w') as hf:
    for i,img in enumerate(image_path_list):
        yset = hf.create_dataset(
            name='y'+str(i),
            data = label_list[i],
            shape=(1,),
            maxshape=(None,))
    print("\r", i, end="")
    




class FlowersDataset(Dataset):

def __init__(self, train = True, transform=None):
    
    if not train:
        if os.path.isfile('flowers_test_data.h5'):
            print ("test data exist")
        else:
            proc_images(data_dir ='flower-data', train = False)
        
    if train:
        if os.path.isfile('flowers_train_data.h5'):
            print ("train data exist")
        else:
            proc_images(data_dir ='flower-data', train = True)
    
    self.transform = transform
    
    if train:
        self.hf_data = h5py.File('flowers_train_data.h5', 'r')
        self.hf_labels = h5py.File('flowers_train_labels.h5', 'r')
    else:
        self.hf_data = h5py.File('flowers_test_data.h5', 'r')
        self.hf_labels = h5py.File('flowers_test_labels.h5', 'r')

def __len__(self):
    return len(self.hf_data)

def __getitem__(self, idx):

    image = Image.fromarray(self.hf_data.get('X' + str(idx)).value)
    label = self.hf_labels.get('y' + str(idx)).value


    if self.transform:
        image = self.transform(image)

    return image, label


def read_data(batch_size, valid_size=0.01, num_workers = 2): 
transform_train = transforms.Compose([transforms.RandomRotation(30),
                                       transforms.Resize(255),
                                       transforms.CenterCrop(224),
                                       transforms.RandomHorizontalFlip(),
                                       transforms.ToTensor(),
                                       transforms.Normalize([0.485, 0.456, 0.406],
                                                            [0.229, 0.224, 0.225])])

test_valid_transforms = transforms.Compose([transforms.Resize(255), 
                                      transforms.CenterCrop(224),
                                      transforms.ToTensor(),
                                      transforms.Normalize([0.485, 0.456, 0.406],
                                                           [0.229, 0.224, 0.225])])

trainset = FlowersDataset(train=True, transform=transform_train)
validset = FlowersDataset(train=True, transform=test_valid_transforms)
    
testset = FlowersDataset(train=False, transform=test_valid_transforms)

num_train = len(trainset)
indices = torch.randperm(num_train).tolist()
split = int(np.floor(valid_size * num_train))

    
train_idx, valid_idx = indices[split:], indices[:split]
#print(len(valid_idx))
train_sampler = SubsetRandomSampler(train_idx)
valid_sampler = SubsetRandomSampler(valid_idx)
    
train_loader = DataLoader(trainset, batch_size=batch_size, sampler=train_sampler, num_workers=num_workers, pin_memory=True)
valid_loader = DataLoader(validset, batch_size=batch_size, sampler=valid_sampler, num_workers=num_workers, pin_memory=True)
    
test_loader = DataLoader(testset, batch_size=batch_size, shuffle=False, num_workers=num_workers, pin_memory=True)

return train_loader, valid_loader, test_loader

AttributeError: Caught AttributeError in DataLoader worker process 0.
Original Traceback (most recent call last):
File “/anaconda3/lib/python3.6/site-packages/torch/utils/data/_utils/worker.py”, line 178, in _worker_loop
data = fetcher.fetch(index)
File “/anaconda3/lib/python3.6/site-packages/torch/utils/data/_utils/fetch.py”, line 44, in fetch
data = [self.dataset[idx] for idx in possibly_batched_index]
File “/anaconda3/lib/python3.6/site-packages/torch/utils/data/_utils/fetch.py”, line 44, in
data = [self.dataset[idx] for idx in possibly_batched_index]
File “/Users/abdolghaniebrahimi/Documents/neuron-compression/untitled10.py”, line 102, in getitem
image = Image.fromarray(self.hf_data.get(‘X’ + str(idx)).value)
AttributeError: ‘NoneType’ object has no attribute ‘value’

Could you print the idx inside __getitem__, use num_workers=0 and retry the code?
Based on the error message it seems some indices yield invalid data using your hdf5 file.
The last printed index should point to the failing operation.

I do not get any error when I set the num_workers = 0. However setting it to 2 gives me error. I am confused.

Have a look at this topic where we debugged some multiprocessing issues with HDF5.

I tried the following code with num_workers = 2 and got this result:

def __getitem__(self, idx):
    print(idx)
    print(self.hf_data.get('X' + str(idx)))
    image = Image.fromarray(np.array((self.hf_data.get('X' + str(idx)))))
    label = torch.LongTensor(np.array(self.hf_labels.get('y' + str(idx))))


    if self.transform:
        image = self.transform(image)

    return image, label

Result

1664
1110
None
<HDF5 dataset "X1110": shape (503, 500, 3), type "|u1">

Which tells me id 1664 is None. If I do:

print(hf_data.get('X' + str(1664)))
<HDF5 dataset "X1664": shape (500, 694, 3), type "|u1">

which is not None.
My HDF5 Version: 1.10.2 is according to the link, so not sure what is wrong.