My actual problem:
I am training a tiny MLP (~1M parameters) on a lot of data (~5TB). The data consists of 2D matrices saved in HDF5 format with blosc compression. Each matrix is saved to a separate file and is around 25MB on disk (50MB after decompression). The matrices are passed to the network one by one, so no batching is needed (just a shuffle for each epoch).
The hardware setup is 4x 4090 with DDP on a single node, and I find training speed is severely bottlenecked by data loading. The GPUs run at around 20% utilization each, and total disk read speed is around 4GB/s (the upper limit of my disk setup should be around 20GB/s; I get >10GB/s by simply `cp`-ing the dataset).
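For context, the real dataset class is roughly the sketch below. This is illustrative rather than my exact code: it assumes `h5py` plus `hdf5plugin` so the blosc filter can be decoded, and that each file stores its matrix under a hypothetical key `"data"`.

```python
import os
import h5py
import hdf5plugin  # noqa: F401  (registers the blosc filter for h5py)
import torch
from torch.utils.data import Dataset

class MatrixDataset(Dataset):
    """One blosc-compressed HDF5 file per sample; each file holds a single 2D matrix."""

    def __init__(self, data_dir):
        self.files = sorted(
            os.path.join(data_dir, f) for f in os.listdir(data_dir) if f.endswith(".h5")
        )

    def __len__(self):
        return len(self.files)

    def __getitem__(self, idx):
        with h5py.File(self.files[idx], "r") as f:
            mat = f["data"][...]        # ~50MB after decompression
        return torch.from_numpy(mat)    # one matrix per training step, no batching
```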
I tried increasing `num_workers` for the `DataLoader`, but total throughput does not change from `num_workers=1` to `num_workers=4` or `num_workers=16`. In `htop` it looks as if there is a lock between the processes: with `num_workers=1` the single worker runs at 100% CPU utilization; with `num_workers=4` each worker only reaches 25%; and with `num_workers=16` each worker sits at about 7%, so adding workers is useless.
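For reference, the loader in the training script is configured roughly like this (a sketch; `MatrixDataset` and `data_dir` are the illustrative names from the snippet above, and in the real DDP run the per-epoch shuffle is handled by a `DistributedSampler` instead of `shuffle=True`):

```python
from torch.utils.data import DataLoader

loader = DataLoader(
    MatrixDataset(data_dir),  # per-file Dataset sketched above
    batch_size=None,          # disable automatic batching: one matrix per step
    shuffle=True,             # reshuffle each epoch
    num_workers=4,            # tried 1 / 4 / 16 -- total throughput does not change
    pin_memory=True,
)
```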
Minimal example:
After some messing around, I managed to create a minimal example in which a single `import torch` creates the same bottleneck (even though `torch` is never used!).
```python
import os
import numpy as np
import multiprocessing

# import torch  # comment or uncomment this line

class CSDataset:  # simulates the torch Dataset interface, but is not a subclass
    def __init__(self, dir):
        self.files = [f"{dir}/{file}" for file in os.listdir(dir)]

    def __len__(self):
        return len(self.files)

    def __getitem__(self, idx):
        with open(self.files[idx], "rb") as f:
            return f.read()

dset = CSDataset(data_dir)  # data_dir points at the dataset directory
indices = np.arange(len(dset))
ids = np.array_split(indices, 8)

def work(id):  # fake workload that just reads the data
    return len([dset[i] for i in id])

with multiprocessing.Pool(8) as p:
    print(p.map(work, ids))
```
If `import torch` is commented out: runtime 16s, disk read speed >10GB/s, each process at ~80% CPU.
If `import torch` is kept: runtime 70s, disk read speed ~4GB/s, each process at ~11% CPU.
So what is `torch` doing that results in poor disk read performance and this apparent inter-process lock?
torch version: 2.0.1+cu118
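In case it helps narrow things down, this is the kind of probe I can run inside each pool worker to see whether CPU affinity or OpenMP settings change once `torch` is imported in the parent. It is only a guess at where to look, not a confirmed explanation, and `os.sched_getaffinity` is Linux-only:

```python
import os
import multiprocessing

def probe(_):
    # Report per-worker CPU affinity and OpenMP-related settings.
    return {
        "pid": os.getpid(),
        "affinity": sorted(os.sched_getaffinity(0)),  # Linux-only
        "OMP_NUM_THREADS": os.environ.get("OMP_NUM_THREADS"),
    }

if __name__ == "__main__":
    # import torch  # toggle this line and compare the printed results
    with multiprocessing.Pool(8) as p:
        for row in p.map(probe, range(8)):
            print(row)
```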