I have an IterableDataset and, say, 2 workers: worker 0 generates an audio clip and combines it with other pre-generated audio clips to create a data point, while worker 1 only combines pre-generated audio clips to create a data point. Naturally, worker 0 takes more time than worker 1 (generating an audio clip and saving it to disk takes ≈ 0.8 s, while loading 12 audio clips takes ≈ 0.18 s).
Since the workers run in parallel, I expected the average loading time to be at most 0.18 s (while worker 0 is busy generating an audio clip, worker 1 keeps producing data points). That is not the case, however: I'm getting an average loading time of about 0.52 s, almost three times slower. How can I solve this issue? Is that even possible?
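For context, here is a minimal self-contained toy that reproduces the timing pattern I see (`TwoSpeedDataset` is a hypothetical stand-in, with `time.sleep` in place of the real generation/loading work). As far as I understand, the DataLoader fetches from its workers in round-robin order, so the slow worker gates every other fetch:

```python
import time
import torch
from torch.utils.data import DataLoader, IterableDataset

class TwoSpeedDataset(IterableDataset):
    """Toy stand-in: worker 0 'generates' (sleeps 0.8 s per item),
    any other worker only 'loads' (sleeps 0.18 s per item)."""
    def __iter__(self):
        info = torch.utils.data.get_worker_info()
        delay = 0.8 if (info is None or info.id == 0) else 0.18
        while True:
            time.sleep(delay)
            yield torch.zeros(1)

loader = DataLoader(TwoSpeedDataset(), batch_size=None, num_workers=2)
it = iter(loader)
times = []
for _ in range(10):
    start = time.perf_counter()
    next(it)
    times.append(time.perf_counter() - start)
avg = sum(times) / len(times)
print(f"avg fetch time: {avg:.2f}s")  # well above 0.18 s: the slow worker gates every other fetch
```

The average lands far closer to my observed ≈ 0.5 s than to the 0.18 s I hoped for, even though the fast worker's items are prefetched.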
Here is some pseudocode:
```python
class GeneratedDataset(IterableDataset):
    def __init__(self, note_num):
        super().__init__()
        self.note_num = note_num

    def __iter__(self):
        worker = torch.utils.data.get_worker_info()
        if worker:
            id = worker.id
            np.random.seed(int(2**31 * torch.rand(1)))
        else:
            id = 0
        while True:
            if id == 0:
                # this is the code that only runs on worker 0
                # GENERATE AUDIO CLIP
                # GENERATE RESPECTIVE LABELS
                # SAVE BOTH
                # ------------------------
                # LOAD AUDIO CLIP
                # GENERATE SPECTROGRAM
                # TENSOR MANIPULATION (padding, normalizing, etc.)
                # SAVE THE TENSOR
                pass
            else:
                # LOAD RANDOM (pre-saved) TENSOR
                # LOAD RESPECTIVE LABELS
                for _ in range(12):
                    # MORE TENSOR MANIPULATION
                    # PROCESS THE LABELS
                    # CHOOSE ANOTHER (TENSOR, LABELS) PAIR FOR NEXT LOOP
                    pass
                # CONCATENATE THE LOADED TENSORS; SAME FOR THE LABELS
            yield  # TENSOR, LABELS
```
Any help would be greatly appreciated.