I have a compute-bound data loading step, and was hoping to improve things by scaling up num_workers. Why can’t I increase throughput by parallelizing compute with multiprocessing (via num_workers > 0 in DataLoader)?
Setup
I have a script (below) to demonstrate the issue. High level setup:
- IterableDataset that generates a batch for each iteration. I’m parallelizing with num_workers in DataLoader, and I’m manually sharding within the dataset based on worker id.
- Batch is randomly generated to rule out any file or network IO
- Extra (useless) computation is done when generating a batch to simulate compute-bound data preparation and to ensure computation dominates any overhead from passing data back from the workers to the main process
- I’m doing almost nothing with the batch in the main process, to simulate the extreme situation where data prep completely dominates actual training
- Running on a machine w/ 24 CPUs
- PyTorch 2.0.0
Full script
# data_test.py
import sys

import numpy as np
import torch
import torch.utils.data as data
from tqdm.auto import tqdm


class MyIterDataset(data.IterableDataset):
    def __init__(self, num_batches, extra_work):
        self.num_batches = num_batches
        self.extra_work = extra_work

    def __iter__(self):
        worker_info = data.get_worker_info()
        for ix in range(self.num_batches):
            # manual sharding: each batch index is produced by exactly one worker
            if worker_info is None or ((ix + worker_info.id) % worker_info.num_workers == 0):
                yield self.gen_batch(ix)

    def gen_batch(self, ix):
        # generate some arbitrary data
        x = np.random.randn(1024) + ix
        for _ in range(self.extra_work):
            # computing things here rather than time.sleep
            # to ensure this really is compute bound
            np.dot(x, x)
        x = torch.from_numpy(x)
        return x


if __name__ == "__main__":
    if len(sys.argv) < 2:
        print(f"""Usage: python {__file__} <num_workers> <num_batches:50> <extra_work:100000>
        num_workers - num_workers arg to DataLoader
        num_batches - number of batches to produce
        extra_work  - how much extra computation to do to slow things down""")
        sys.exit(0)

    num_workers = int(sys.argv[1])
    num_batches = 50 if len(sys.argv) < 3 else int(sys.argv[2])
    extra_work = 100000 if len(sys.argv) < 4 else int(sys.argv[3])

    ds = MyIterDataset(num_batches, extra_work)
    dl = data.DataLoader(ds, batch_size=None, num_workers=num_workers, collate_fn=None)

    n = 0
    # minimal compute in main process
    for df in tqdm(dl):
        n += len(df)
    print(f"got {n} samples, expected {1024 * num_batches}")
Observations
Total time to generate all batches increases as I add workers. I would expect 1 worker to be slightly slower than no multiprocessing because of comms overhead, but why is there no advantage to parallelizing with more than 1 worker?
Workers | Time |
---|---|
No multiprocessing | 0m8.989s |
1 worker | 0m10.037s |
2 workers | 0m10.596s |
4 workers | 0m11.036s |
8 workers | 0m12.752s |
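(The numbers above are wall-clock times for the whole script. A rough harness along these lines reproduces the comparison; it is a sketch, not the exact commands I used.)

```python
# time_runs.py -- rough harness to reproduce the comparison above
# (a sketch; I originally just timed whole-script runs from the shell).
import subprocess
import sys
import time

for workers in [0, 1, 2, 4, 8]:
    t0 = time.perf_counter()
    subprocess.run([sys.executable, "data_test.py", str(workers)], check=True)
    elapsed = time.perf_counter() - t0
    label = "no multiprocessing" if workers == 0 else f"{workers} worker(s)"
    print(f"{label}: {elapsed:.1f}s")
```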
The machine I’m running on has 24 cores, but from looking at top, it seems that total compute for all processes (main + workers) only sums up to about 200%, regardless of how many workers are used. Why is this?
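For reference, here is roughly how I cross-check what top shows, summing CPU% over the main process and its children. This uses psutil, which is not part of the script above; run it in another terminal, passing the PID of the running data_test.py process.

```python
# cpu_check.py -- rough cross-check of what `top` reports (assumes psutil is installed)
import sys
import time

import psutil

pid = int(sys.argv[1])                      # PID of the running data_test.py
main = psutil.Process(pid)
procs = [main] + main.children(recursive=True)  # main process + DataLoader workers

for p in procs:
    p.cpu_percent(None)                     # prime the per-process counters
time.sleep(1.0)                             # sample over one second
total = sum(p.cpu_percent(None) for p in procs)
print(f"{len(procs)} processes, total CPU ~{total:.0f}%")
```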
I would expect to be able to scale up to about 24 workers for this test script fairly efficiently. This makes me suspect I’m either doing something wrong or missing something obvious. I’d appreciate it if someone could point out my mistake.
Other thoughts
I’ve tried playing with prefetch_factor, but it did not seem to help. Roughly what I tried is sketched below.
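(Sketch of the prefetch_factor experiment; the values are examples, not necessarily the exact ones I used. prefetch_factor only has an effect when num_workers > 0.)

```python
import torch.utils.data as data
from data_test import MyIterDataset

ds = MyIterDataset(num_batches=50, extra_work=100_000)
# same setup as the script, but with a larger-than-default prefetch_factor
dl = data.DataLoader(ds, batch_size=None, num_workers=4,
                     collate_fn=None, prefetch_factor=8)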
In the current script, the dataset is returning batches of torch Tensors and collate_fn is set to None. I’ve also tried returning numpy arrays, with and without a collate function, and I’ve tried calling share_memory_ before returning a batch from the dataset. None of these seems to affect things.
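(Rough sketches of those variants, from memory rather than the exact code I ran:)

```python
# (a) convert in a collate_fn instead of inside gen_batch. With batch_size=None,
#     the collate_fn receives one item (here: one pre-built batch) at a time.
#     torch.as_tensor works whether gen_batch yields a numpy array or a tensor.
import torch
import torch.utils.data as data
from data_test import MyIterDataset

def to_tensor_collate(batch):
    return torch.as_tensor(batch)

ds = MyIterDataset(num_batches=50, extra_work=100_000)
dl = data.DataLoader(ds, batch_size=None, num_workers=4,
                     collate_fn=to_tensor_collate)

# (b) move the tensor into shared memory before returning it from gen_batch:
#     x = torch.from_numpy(x)
#     x.share_memory_()
#     return x
```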
To test whether communication/serialization of batch data is the culprit, I’ve tried increasing “extra work” by orders of magnitude, so that the amount of data being passed is tiny relative to the computation. Total compute across all processes is always capped at about 200%, regardless of num_workers. This seems to be the smoking gun, but I don’t know what to make of it.
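As a back-of-the-envelope check (not part of the script above), something like the sketch below compares the cost of the per-batch “extra work” against the cost of serializing one batch; pickling is only a stand-in here, since the real workers hand tensors back through shared memory via torch.multiprocessing, so if anything this overstates the transfer cost.

```python
# serialize_check.py -- rough sanity check, not part of data_test.py
import pickle
import time

import numpy as np
import torch

x = np.random.randn(1024)

# cost of the "extra work" loop for one batch (default extra_work=100000)
t0 = time.perf_counter()
for _ in range(100_000):
    np.dot(x, x)
work_time = time.perf_counter() - t0

# cost of serializing one batch
batch = torch.from_numpy(x)
t0 = time.perf_counter()
for _ in range(1000):
    payload = pickle.dumps(batch)
pickle_time = (time.perf_counter() - t0) / 1000

print(f"extra work per batch: {work_time:.3f}s, "
      f"pickling one batch: {pickle_time * 1e6:.0f}us ({len(payload)} bytes)")
```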