Use of Dataset class

I have a question about how I can use the Dataset class on my own data.
The problem is that my data is in several hdf5 files (a medical imaging dataset). Each of these files has several examples (possibly a different number in each file). If I want to create my own dataset, the documentation says that I should provide the `__len__` and `__getitem__` methods. However, I don't really have a general index that identifies a sample; I only have an index inside each hdf5 file, and I am not sure how I can use all the data without loading all the hdf5 files into memory.
Any suggestion is appreciated.
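One way to get a general index is to build a lookup table once: for every sample, store which file it lives in and its position inside that file. This is a minimal sketch that assumes you already know how many samples each file holds (for example from `len(h5_file['data'])`); `build_index_table` is a hypothetical helper name. Only this small table is kept in memory, not the images themselves.

```python
from typing import List, Tuple


def build_index_table(counts: List[int]) -> List[Tuple[int, int]]:
    """Map every global sample index to a (file_index, local_index) pair.

    counts[i] is the number of samples stored in the i-th file.
    """
    table = []
    for file_index, n in enumerate(counts):
        for local_index in range(n):
            table.append((file_index, local_index))
    return table


# Example: three files holding 3, 1 and 2 samples.
table = build_index_table([3, 1, 2])
print(len(table))  # 6: total number of samples
print(table[3])    # (1, 0): first sample of the second file
print(table[5])    # (2, 1): last sample of the third file
```

`__len__` can then return `len(table)` and `__getitem__` can look up `table[index]` to find the file and the position to read. The table costs one tuple per sample, which is usually small next to the imaging data.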

I think the simplest (but not elegant) way is to calculate the total number of samples by hand and then map each index to a particular hdf5 file. For example, I would use:

    import h5py
    from torch.utils import data

    class MergedDataset(data.Dataset):

        def __init__(self, hdf5_list):
            self.datasets = []
            self.total_count = 0
            for f in hdf5_list:
                h5_file = h5py.File(f, 'r')
                dataset = h5_file['YOUR DATASET NAME']
                self.datasets.append(dataset)  # keep a handle to each file's dataset
                self.total_count += len(dataset)

        def __getitem__(self, index):
            # Suppose each hdf5 file has exactly 10000 samples
            dataset_index = index // 10000     # which file the sample is in
            in_dataset_index = index % 10000   # position of the sample inside that file
            return self.datasets[dataset_index][in_dataset_index]

        def __len__(self):
            return self.total_count

I am not sure whether opening an hdf5 file loads all of its data into memory or not. I hope this helps.

Thank you very much @cyyyyc123, this seems like a good solution. I would also like to know whether this loads all the data into memory, or only a reference, with data loaded into memory only when we index it. Does anybody know?

@cyyyyc123, I have implemented your idea with a small addition, since I don't have the same number of samples in each hdf5 file. It looks like this:

    import glob
    import os

    import h5py
    from torch.utils import data

    class CT_dataset(data.Dataset):

        def __init__(self, path_patients):
            hdf5_list = sorted(glob.glob(os.path.join(path_patients, '*.h5')))  # only h5 files
            print('h5 list ', hdf5_list)
            self.datasets = []
            self.datasets_gt = []
            self.limits = []  # global index of the first sample in each file
            self.total_count = 0
            for f in hdf5_list:
                h5_file = h5py.File(f, 'r')
                dataset = h5_file['data']
                dataset_gt = h5_file['label']
                self.datasets.append(dataset)
                self.datasets_gt.append(dataset_gt)
                self.limits.append(self.total_count)
                self.total_count += len(dataset)

        def __getitem__(self, index):
            # Scan the files from last to first and stop at the first one
            # whose start offset is <= index.
            dataset_index = -1
            for i in range(len(self.limits) - 1, -1, -1):
                if index >= self.limits[i]:
                    dataset_index = i
                    break
            assert dataset_index >= 0, 'negative chunk'

            in_dataset_index = index - self.limits[dataset_index]

            return self.datasets[dataset_index][in_dataset_index], self.datasets_gt[dataset_index][in_dataset_index]

        def __len__(self):
            return self.total_count

It works fine except when I use num_workers >= 2 in the DataLoader. Do you have any idea what could be happening?
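With a sorted list of start offsets, the backward linear scan over `self.limits` in `__getitem__` can be replaced by a binary search from Python's standard `bisect` module. This is a sketch: `locate` is a hypothetical helper, and `limits` is assumed to hold the global index at which each file starts (so `limits[0] == 0`), matching how `in_dataset_index = index - self.limits[dataset_index]` uses it.

```python
import bisect
from typing import List, Tuple


def locate(index: int, limits: List[int], total: int) -> Tuple[int, int]:
    """Return (dataset_index, in_dataset_index) for a global sample index.

    limits[i] is the global index of the first sample of file i, so the
    list is sorted and starts at 0; total is the overall number of samples.
    """
    if not 0 <= index < total:
        raise IndexError(index)
    # bisect_right finds the first start offset strictly greater than index;
    # the file we want is the one just before it. Empty files (repeated
    # offsets) are skipped automatically.
    dataset_index = bisect.bisect_right(limits, index) - 1
    return dataset_index, index - limits[dataset_index]


limits = [0, 3, 4]  # files with 3, 1 and 2 samples
print(locate(0, limits, 6))  # (0, 0)
print(locate(3, limits, 6))  # (1, 0)
print(locate(5, limits, 6))  # (2, 1)
```

The lookup is O(log n) in the number of files instead of O(n). `torch.utils.data.ConcatDataset` performs a similar bisect-based lookup over its cumulative sizes, so wrapping one dataset per file in `ConcatDataset` is another option.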

Yes! I also ran into a problem with num_workers >= 2: if I re-index the mini-batch data fetched by a DataLoader with num_workers >= 2, the data is sometimes in the wrong order. When I set num_workers=1, everything works well again. I suppose this is because, with multiple workers (multiple processes), PyTorch does not provide locks to synchronize the data. In most cases this is fine because we only use the data without modifying it, but if we want to modify the fetched data, I guess we should use only one process to avoid modifying data asynchronously. I hope this is helpful.

Hi Roger, I’m curious about the question you asked. Does this solution load all the data into memory?

Do you know how to solve the problem caused by num_workers >= 2?
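A commonly suggested explanation is that h5py file handles opened in `__init__` are inherited by the forked DataLoader worker processes and do not behave well when shared; the usual workaround is to read only the sample counts in `__init__`, close the files, and let each worker open its own handles the first time `__getitem__` runs. The sketch below assumes each file stores its samples under one key. `LazyH5Dataset` and its `open_file` parameter are hypothetical names (the parameter only exists so the opener can be swapped, e.g. for testing) and it defaults to `h5py.File(path, 'r')`. To stay dependency-free it does not subclass `torch.utils.data.Dataset`, but it defines the same `__len__`/`__getitem__` pair. Since h5py reads only the requested slice when a dataset is indexed, memory use stays at one sample per call.

```python
import bisect
import os
from typing import Any, Callable, List, Optional, Sequence


def _default_open(path: str) -> Any:
    # Imported here so the class can be defined without h5py installed.
    import h5py
    return h5py.File(path, 'r')


class LazyH5Dataset:
    """Map-style dataset over several HDF5 files, opened lazily per process."""

    def __init__(self, paths: Sequence[str], key: str = 'data',
                 open_file: Optional[Callable[[str], Any]] = None) -> None:
        self.paths = list(paths)
        self.key = key
        self.open_file = open_file or _default_open
        # Read only the sample counts, then close each file again so that
        # no open handle is inherited by the worker processes.
        self.limits: List[int] = []  # global index of each file's first sample
        self.total_count = 0
        for path in self.paths:
            f = self.open_file(path)
            try:
                n = len(f[self.key])
            finally:
                f.close()
            self.limits.append(self.total_count)
            self.total_count += n
        self._files: Optional[List[Any]] = None
        self._pid: Optional[int] = None

    def _ensure_open(self) -> None:
        # Open handles on first use in this process (e.g. inside a worker).
        if self._files is None or self._pid != os.getpid():
            self._files = [self.open_file(p) for p in self.paths]
            self._pid = os.getpid()

    def __len__(self) -> int:
        return self.total_count

    def __getitem__(self, index: int) -> Any:
        if not 0 <= index < self.total_count:
            raise IndexError(index)
        self._ensure_open()
        # Last file whose start offset is <= index.
        file_index = bisect.bisect_right(self.limits, index) - 1
        return self._files[file_index][self.key][index - self.limits[file_index]]
```

It would be passed to a `DataLoader` like any other map-style dataset; returning the `label` array as well only needs a second key read in `__getitem__`.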

I cannot find an example of using an .h5 file to train a model. Could you help, please?

The problem can be solved by replacing hdf5 with zarr. It works well with num_workers >= 2.