How to use fast and slow workers together in DataLoader?

Hello everyone, I'm a newbie to PyTorch.

My training uses two data simulation approaches: one is fast and one is much slower. I use the worker id to decide which approach a worker runs, e.g. with 10 workers in total, workers 0 to 8 use the fast simulation and worker 9 uses the slow one.

Since the DataLoader uses multiple worker processes, I assumed it and the training process would work in producer-consumer mode: once a worker produces a data batch, the batch is added to a queue; on the other side, the training process takes batches from the queue and only waits when the queue is empty.
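
To illustrate what I mean, here is a minimal sketch of that producer-consumer mode using plain multiprocessing (this is just my mental model, not PyTorch's actual implementation): the consumer keeps receiving items from the fast producer even while the slow producer is busy.

import multiprocessing as mp
import time


def producer(delay, q):
    # Each producer runs in its own process and pushes a result
    # into the shared queue as soon as it is ready.
    while True:
        time.sleep(delay)
        q.put(delay)


if __name__ == '__main__':
    q = mp.Queue(maxsize=8)
    mp.Process(target=producer, args=(0.5, q), daemon=True).start()  # fast
    mp.Process(target=producer, args=(5.0, q), daemon=True).start()  # slow
    for _ in range(20):
        item = q.get()  # unblocks as soon as *any* producer has data
        print('got an item from the producer with delay', item)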

However, I found that the training time is the same as if all workers used the slow approach. So I deduce that batches are not handed back as soon as any worker finishes; instead the loader seems to fetch from the workers in a fixed loop, so data simulation is slowed down to the pace of worker 9.
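
If that deduction is right, the loader's behavior is closer to this sketch (again just my mental model, not the actual implementation): it cycles through the workers in a fixed order, so every pass around the loop waits on the slow worker.

import itertools


def round_robin(worker_iterators):
    # Poll the workers in a fixed order; the consumer therefore
    # waits on the slow worker once per cycle.
    for it in itertools.cycle(worker_iterators):
        yield next(it)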

My DataLoader is initialized like this:

dataloader = torch.utils.data.DataLoader(dataset, batch_size=BATCH_SIZE, num_workers=NUM_WORKERS)

My question is: is there any way to keep the training process from being stalled by the slow worker? Is there a parameter I am missing? Thank you very much!

For example, I have tried this toy dataset:

import time

import torch


class ToyDataset(torch.utils.data.IterableDataset):
    def __init__(self, numworkers):
        super().__init__()
        self.numworkers = numworkers
        # Each worker process gets its own copy of the dataset,
        # so this counter is per worker, not global.
        self.batchidx = 0

    def __iter__(self):
        return self

    def __next__(self):
        # Default to worker 0 when running in the main process
        # (num_workers = 0).
        worker_id = 0
        info = torch.utils.data.get_worker_info()
        if info is not None:
            worker_id = info.id

        # The last worker simulates the slow approach;
        # all other workers simulate the fast one.
        if worker_id < self.numworkers - 1:
            time.sleep(0.5)
            print('fast worker {:d}, batch {:d}'.format(worker_id, self.batchidx))
        else:
            time.sleep(5)
            print('-slow worker {:d}, batch {:d}'.format(worker_id, self.batchidx))

        retval = self.batchidx
        self.batchidx += 1
        return retval


BATCH_SIZE = 1
NUM_WORKERS = 2
TIMEOUT = 0  # 0 means wait for a batch indefinitely
PREFETCH_FACTOR = 2  # batches loaded in advance by each worker


dataset = ToyDataset(NUM_WORKERS)
dataloader = torch.utils.data.DataLoader(
    dataset,
    batch_size=BATCH_SIZE,
    num_workers=NUM_WORKERS,
    timeout=TIMEOUT,
    prefetch_factor=PREFETCH_FACTOR)
it = iter(dataloader)
it = iter(dataloader)

for i in range(1000):
    try:
        data = next(it)
        print('----------')
    except RuntimeError:  # only raised when a positive timeout expires
        print('timeout')

The output is something like this:
fast worker 0, batch 5
-slow worker 1, batch 3

fast worker 0, batch 6
-slow worker 1, batch 4

fast worker 0, batch 7
-slow worker 1, batch 5
Data generation is always held back by the slow worker: the fast worker's batches come out interleaved with, and paced by, the slow worker's.
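
The only workaround I have come up with so far is to give each simulation approach its own DataLoader and merge the two streams myself through a thread-safe queue, so the training loop takes whichever batch is ready first. A rough sketch of what I mean (SimDataset and the pump helper below are placeholders I made up, not PyTorch APIs); is there a more idiomatic way?

import queue
import threading
import time

import torch


class SimDataset(torch.utils.data.IterableDataset):
    # Stand-in for one simulation approach; delay mimics its speed.
    def __init__(self, delay):
        super().__init__()
        self.delay = delay

    def __iter__(self):
        while True:
            time.sleep(self.delay)
            yield torch.zeros(1)


def pump(loader, out_queue):
    # Drain one DataLoader in a background thread, pushing its
    # batches into the shared queue as soon as they are ready.
    for batch in loader:
        out_queue.put(batch)


if __name__ == '__main__':
    merged = queue.Queue(maxsize=8)
    fast = torch.utils.data.DataLoader(SimDataset(0.5), batch_size=1, num_workers=1)
    slow = torch.utils.data.DataLoader(SimDataset(5.0), batch_size=1, num_workers=1)
    for loader in (fast, slow):
        threading.Thread(target=pump, args=(loader, merged), daemon=True).start()

    for step in range(20):
        batch = merged.get()  # blocks only if *both* loaders are behind
        print('step', step)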