Dear all,
I apologize if this is a simple question, but I haven’t found a simple way to do the following. I have a fairly large dataset (100 GB) that cannot fit in memory. I have therefore built an IterableDataset
as follows:
class datasets( IterableDataset ):
    """Stream (label, variable) tensor pairs from a collection of HDF5 files.

    Files are loaded one at a time, so only a single file's worth of data is
    ever resident in memory.
    """

    def __init__(self, path, device, variables, labels):
        # Resolve the glob pattern once up front; iteration follows glob order.
        self.files = glob.glob(path)
        self.device = device
        self.variables = variables
        self.labels = labels

    def __iter__(self):
        for f in self.files:
            print(f"Reading from {f}")
            frame = pd.read_hdf(f, 'df')
            # Select the requested columns, convert to tensors, and move
            # them to the target device before yielding row by row.
            label_tensor = torch.Tensor(frame[self.labels].values).to(self.device)
            variable_tensor = torch.Tensor(frame[self.variables].values).to(self.device)
            yield from zip(label_tensor, variable_tensor)
This seems to work well, but it’s not particularly efficient: loading each file takes several seconds, during which training is stalled. I was wondering whether there is some way of doing eager prefetching of the files, along the lines of:
- Loads file 1
- Trains in file 1, while it loads file 2
- Trains in file 2, while it loads file 3
Thanks in advance!
Hi,
just in case it is useful for others, I got away with doing something like the code below. Note that this is a rough implementation; the actual code is a bit more complex than this.
This seems to be working quite efficiently. There’s probably some more optimization to be done as the prefetching and the training seem to be competing for the same resources (the training is slower while the prefetching is on-going), but this may be due to the limitations of my hardware.
def _load_next( file_list, index, device, labels, variables):
if index >= len(file_list):
return None
thedata=pd.read_hdf(f, 'df')
labels=torch.Tensor( thedata[labels].values).to(device).__iter__() # creating the iter takes time, doing asynchronously
variables=torch.Tensor( thedata[variables].values).to(device).__iter__()
return index, (labels,variables)
class datasets( IterableDataset ):
    """Iterable dataset that yields (label, variable) pairs from HDF5 files,
    prefetching the next file in a background thread while the current one
    is being consumed by the training loop.
    """

    def __init__(self, path, device, variables=None, labels=None):
        self.files = glob.glob(path)
        self.device = device
        # BUG FIX: the original never stored or forwarded the column
        # selections, yet _load_next requires them. Accept them here with
        # defaults so existing `datasets(path, device)` callers still work.
        self.variables = variables
        self.labels = labels
        self.executor = ThreadPoolExecutor(max_workers=1)
        self.restart()

    def restart(self):
        """Synchronously load the first file and schedule the second."""
        print("Re-starting iterator")
        first = _load_next(self.files, 0, self.device, self.labels, self.variables)
        if first is None:
            # _load_next returns None when there is nothing to load; the
            # original would have crashed on the tuple unpacking below.
            raise RuntimeError("no input files matched the given path")
        self.file_index, self.current_data = first
        # prefetch the following file in the background
        self.prefetch = self.executor.submit(
            _load_next, self.files, self.file_index + 1,
            self.device, self.labels, self.variables)

    def __iter__(self):
        while True:
            yield from zip(self.current_data[0], self.current_data[1])
            # BUG FIX: the original tested an undefined name `result`; we
            # must block on the prefetch future once the current file is
            # exhausted.
            result = self.prefetch.result()
            if result is None:
                # Nothing left to prefetch: rewind so the dataset can be
                # iterated again (e.g. for the next epoch).
                # BUG FIX: the original called executor.shutdown() before
                # restart(), but restart() submits to that same executor,
                # which raises on a shut-down pool. Keep the executor alive.
                self.restart()
                break
            # End of the current file and the next one is ready: switch to
            # it and immediately start prefetching the one after.
            self.file_index, self.current_data = result
            self.prefetch = self.executor.submit(
                _load_next, self.files, self.file_index + 1,
                self.device, self.labels, self.variables)