Hi,
I read here that we can use the `get()` method in `Dataset` to load a single file before passing the `Data` objects as a list to `DataLoader` to be fed into model training. In my case, I want to load multiple files with multiprocessing instead of one file at a time to speed up the process. How can I do that? Thanks
To be clear, are you asking how to call `get()` in multiple processes simultaneously and gather each of their resulting `Data` objects in the main process? And do you want each subprocess to load multiple files? Here's a simple example of using `Pool` to load two files, `data0.pt` and `data1.pt`, with two subprocesses:
```python
import torch
from torch.multiprocessing import Pool

def load(idx):
    return torch.load(f'data{idx}.pt')

if __name__ == '__main__':
    with Pool(2) as p:
        data_list = p.map(load, [0, 1])
```

but beware of the overhead of spawning subprocesses.
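One way to amortize that overhead is to hand each subprocess a chunk of files rather than one task per file, so fewer tasks cross the process boundary. A minimal sketch with the standard-library `multiprocessing` (the `load_chunk` helper and its string return values are placeholders for your actual `torch.load` calls):

```python
from multiprocessing import Pool

def load_chunk(indices):
    # Each subprocess loads several files in one task, amortizing
    # per-task overhead. Stand-in for [torch.load(f'data{i}.pt') ...].
    return [f'data{i}' for i in indices]

def load_all(chunks, workers=2):
    # Map one chunk of file indices to each worker, then flatten
    # the per-chunk results back into a single list.
    with Pool(workers) as p:
        return [d for chunk in p.map(load_chunk, chunks) for d in chunk]

if __name__ == '__main__':
    data_list = load_all([[0, 1], [2, 3]])
```

The resulting flat `data_list` can then be handed to `DataLoader` as before.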
Hey @ArchieGertsman, thanks for the answer!
Basically, I want something like `num_parallel_calls` (for multiprocessing) and `.prefetch()` (to prepare the next batch of data on the CPU while training runs on the GPU), as in `tf.data.Dataset` here, but using the torch_geometric `Dataset`. How can we do that? As far as I understand, `__getitem__` or `get()` only loads one file at a time.
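For what it's worth, `torch.utils.data.DataLoader` (which PyG's `DataLoader` builds on) already provides both knobs: `num_workers=N` makes N subprocesses call your `Dataset`'s `get()`/`__getitem__` in parallel, and `prefetch_factor` controls how many batches each worker prepares ahead while the GPU is busy. The prefetch pattern itself can be sketched with only the standard library; `load` below is a hypothetical stand-in for your per-file loader:

```python
from collections import deque
from concurrent.futures import ThreadPoolExecutor

def load(idx):
    # Stand-in for e.g. torch.load(f'data{idx}.pt').
    return f'data{idx}'

def prefetching_loader(indices, workers=2, prefetch=4):
    # Keep up to `prefetch` loads in flight in background threads,
    # yielding results in order while later loads overlap with the
    # consumer's work (the num_parallel_calls + prefetch idea).
    with ThreadPoolExecutor(max_workers=workers) as pool:
        pending = deque()
        for idx in indices:
            pending.append(pool.submit(load, idx))
            if len(pending) >= prefetch:
                yield pending.popleft().result()
        while pending:
            yield pending.popleft().result()

for data in prefetching_loader(range(4)):
    ...  # train on `data` while the next files load in the background
```

Threads are enough to hide I/O latency here; for CPU-bound decoding you would swap in `ProcessPoolExecutor`, which is essentially what `DataLoader(num_workers=N)` manages for you.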