Eager loading of dataset files

Dear all,

I apologize if this is a simple question, but I haven't found a simple way to do the following. I have a rather large dataset (100 GB) that does not fit in memory, so I have built an IterableDataset as follows:

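import glob
import os
import pandas as pd
import torch
from torch.utils.data import IterableDataset

class datasets(IterableDataset):
    def __init__(self, path, device, variables, labels):
        super().__init__()
        self.files = sorted(glob.glob(os.path.join(path, "*.h5")))  # assumed: one .h5 file per chunk
        self.device, self.variables, self.labels = device, variables, labels
    def __iter__(self):
        for f in self.files:
            print(f"Reading from {f}")
            thedata = pd.read_hdf(f, 'df')  # blocking: training waits while this file is read
            labels = torch.Tensor(thedata[self.labels].values).to(self.device)
            variables = torch.Tensor(thedata[self.variables].values).to(self.device)
            yield from zip(labels, variables)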

This seems to work, but it is not particularly efficient: loading each file takes several seconds, during which training is stalled. I was wondering whether there is some way to prefetch the files eagerly, so that the loop:

  1. Loads file 1
  2. Trains on file 1 while it loads file 2
  3. Trains on file 2 while it loads file 3, and so on

Thanks in advance!
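The load-one-ahead pattern described in the list above can be sketched in plain Python with a single background thread. This is a minimal, framework-free illustration: `prefetch_one_ahead` and its `load(i)` callback are hypothetical names, and `load(i)` stands in for whatever reads file `i` (e.g. the `pd.read_hdf` plus tensor creation above).

```python
from concurrent.futures import ThreadPoolExecutor
from typing import Callable, Iterator, TypeVar

T = TypeVar("T")

def prefetch_one_ahead(load: Callable[[int], T], n_files: int) -> Iterator[T]:
    """Yield load(0), load(1), ... while load(i + 1) runs in a background thread."""
    if n_files == 0:
        return
    with ThreadPoolExecutor(max_workers=1) as executor:
        pending = executor.submit(load, 0)  # step 1: start loading file 0
        for i in range(n_files):
            data = pending.result()  # waits only if file i is not ready yet
            if i + 1 < n_files:
                # start loading file i + 1 before handing file i to the caller
                pending = executor.submit(load, i + 1)
            yield data  # the caller trains on file i while file i + 1 loads

for chunk in prefetch_one_ahead(lambda i: f"file{i}", 3):
    print(chunk)
```

With `max_workers=1` at most one file is being loaded at a time, so memory holds roughly two files: the one being trained on and the one being prefetched.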

Just in case it is useful for others: I got away with something like the code below. Note that this is a rough implementation; the actual code is a bit more complex.

This seems to work quite efficiently. There is probably more optimization to be done, as the prefetching and the training seem to compete for the same resources (training is slower while prefetching is ongoing), but this may be due to the limitations of my hardware.

import glob
import os
from concurrent.futures import ThreadPoolExecutor

import pandas as pd
import torch
from torch.utils.data import IterableDataset

def _load_next(file_list, index, device, label_cols, variable_cols):
    if index >= len(file_list):
        return None  # no more files to load
    thedata = pd.read_hdf(file_list[index], 'df')
    # creating the tensors and their iterators takes time, so this runs in the worker thread
    labels = iter(torch.Tensor(thedata[label_cols].values).to(device))
    variables = iter(torch.Tensor(thedata[variable_cols].values).to(device))
    return index, (labels, variables)

class datasets(IterableDataset):
    def __init__(self, path, device, variables, labels):
        super().__init__()
        self.files = sorted(glob.glob(os.path.join(path, "*.h5")))  # assumed file naming
        self.device, self.variables, self.labels = device, variables, labels
        self.executor = ThreadPoolExecutor(max_workers=1)

    def _prefetch(self, index):
        # load file `index` in the background thread; returns a Future
        return self.executor.submit(_load_next, self.files, index, self.device,
                                    self.labels, self.variables)

    def restart(self):
        print("Re-starting iterator")
        self.prefetch = self._prefetch(0)  # start loading the first file

    def __iter__(self):
        self.restart()
        while True:
            result = self.prefetch.result()  # blocks only if the next file is not ready yet
            if result is None:  # there's nothing more to load, stop here
                return
            # we are at the end of a file: switch to the prefetched one and start loading the following one
            self.file_index, self.current_data = result
            self.prefetch = self._prefetch(self.file_index + 1)
            yield from zip(*self.current_data)
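To check how much of the loading time the prefetching actually hides on a given machine (the resource-competition remark above), one option is to time the same loop with and without prefetching. The sketch below uses hypothetical `fake_load`/`fake_train` functions and made-up durations (`LOAD_S`, `TRAIN_S`); because `time.sleep` uses no CPU, it shows the ideal case where nothing competes. If the real loader gives times much closer to the sequential figure, loading and training are probably contending for something (CPU, disk, host-to-GPU transfers, or the GIL for the parts of parsing that hold it).

```python
import time
from concurrent.futures import ThreadPoolExecutor

LOAD_S = 0.1   # hypothetical time to read one file
TRAIN_S = 0.1  # hypothetical time to train on one file
N_FILES = 4

def fake_load(i):
    time.sleep(LOAD_S)  # stands in for pd.read_hdf + tensor creation
    return i

def fake_train(data):
    time.sleep(TRAIN_S)  # stands in for the training steps on one file

def run_sequential():
    start = time.perf_counter()
    for i in range(N_FILES):
        fake_train(fake_load(i))  # load, then train: no overlap
    return time.perf_counter() - start

def run_prefetched():
    start = time.perf_counter()
    with ThreadPoolExecutor(max_workers=1) as executor:
        pending = executor.submit(fake_load, 0)
        for i in range(N_FILES):
            data = pending.result()
            if i + 1 < N_FILES:
                pending = executor.submit(fake_load, i + 1)  # overlaps with fake_train below
            fake_train(data)
    return time.perf_counter() - start

if __name__ == "__main__":
    seq, pre = run_sequential(), run_prefetched()
    print(f"sequential: {seq:.2f}s  prefetched: {pre:.2f}s")
```

In the ideal case the sequential loop costs about `N_FILES * (LOAD_S + TRAIN_S)`, while the prefetched one costs about `LOAD_S + N_FILES * max(LOAD_S, TRAIN_S)`, since only the first load is not hidden.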