Multiprocess data loading speed becomes slow after importing torch

My actual problem:
I am training a tiny MLP (~1M parameters) on a lot of data (~5 TB). The data are 2D matrices stored in HDF5 format with Blosc compression. Each matrix is saved to a separate file of about 25 MB on disk (50 MB after decompression). The matrices are meant to be fed to the network one at a time, so no batching is needed (only shuffling every epoch).
The hardware is 4× RTX 4090 with DDP on a single node, and training is severely bottlenecked by data loading. Each GPU runs at around 20% utilization, and total disk read speed is around 4 GB/s, although the disk setup should be able to reach about 20 GB/s (copying the dataset directly with cp gives >10 GB/s).
I tried increasing num_workers in DataLoader, but total throughput does not change between num_workers=1, num_workers=4 and num_workers=16. In htop it looks as if there were a lock between the processes: with num_workers=1 the worker process runs at 100% CPU utilization; with num_workers=4 each worker only reaches 25%, and with num_workers=16 each worker reaches 7%, so adding workers does not help at all.
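The scaling test described above can be reproduced without DataLoader by timing how many bytes a pool of N processes reads per second. A minimal stdlib-only sketch; read_file and measure_throughput are hypothetical helper names, and the generated temporary files only stand in for the real dataset:

```python
import os
import tempfile
import time
from multiprocessing import Pool


def read_file(path):
    # Read the whole file and return the number of bytes read.
    with open(path, "rb") as f:
        return len(f.read())


def measure_throughput(files, num_procs):
    # Read every file once with num_procs worker processes and
    # return (total_bytes, seconds, bytes_per_second).
    start = time.perf_counter()
    with Pool(num_procs) as pool:
        total = sum(pool.map(read_file, files, chunksize=1))
    elapsed = time.perf_counter() - start
    return total, elapsed, total / elapsed


if __name__ == "__main__":
    # Hypothetical test data: 16 files of 4 MB each in a temporary directory.
    with tempfile.TemporaryDirectory() as tmp:
        files = []
        for i in range(16):
            path = os.path.join(tmp, f"sample_{i}.bin")
            with open(path, "wb") as f:
                f.write(os.urandom(4 * 1024 * 1024))
            files.append(path)
        for n in (1, 4, 16):
            total, secs, rate = measure_throughput(files, n)
            print(f"{n:2d} procs: {total / 1e6:.0f} MB in {secs:.2f} s ({rate / 1e9:.2f} GB/s)")
```

Freshly written files sit in the page cache, so the numbers for the temporary files mostly reflect memory and CPU speed; to measure the disk itself, pass a list of real data files that have not been read recently.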

Minimal example:
After some experimenting, I managed to create a minimal example in which a single import torch causes the same bottleneck (even though torch is never used!).

import os
import numpy as np
import multiprocessing
# import torch  # comment or uncomment this line

class CSDataset():  # simulate torch interface, this is not a subclass
    def __init__(self, dir):
        self.files = [f"{dir}/{file}" for file in os.listdir(dir)]

    def __len__(self):
        return len(self.files)

    def __getitem__(self, idx):
        with open(self.files[idx], "rb") as f:
            return f.read()  # read the whole file into memory

data_dir = "/path/to/data"  # placeholder: directory containing the matrix files
dset = CSDataset(data_dir)
indices = np.arange(len(dset))
ids = np.array_split(indices, 8)

def work(id):  # fake workload to read the data
    return len([dset[i] for i in id])

if __name__ == "__main__":  # needed if the start method is later switched to spawn/forkserver
    with multiprocessing.Pool(8) as p:
        print(p.map(work, ids))

If import torch is commented out: runtime 16 s, disk read speed >10 GB/s, each process at ~80% CPU.
If import torch is kept: runtime 70 s, disk read speed ~4 GB/s, each process at ~11% CPU.

So what is torch doing that causes the poor disk read performance and this apparent inter-process lock?
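One hypothesis, not confirmed anywhere in this thread: the per-process CPU figures add up to roughly one core (4 × 25% = 100%, 16 × 7% ≈ 112%, 8 × 11% ≈ 88%), which is what you would see if every worker were restricted to the same CPU. If importing torch (or a threading library it loads) narrowed the process's CPU affinity mask, forked workers would inherit that mask. On Linux this can be checked with os.sched_getaffinity; allowed_cpus below is a hypothetical helper:

```python
import os
from multiprocessing import Pool


def allowed_cpus(_=None):
    # CPUs the calling process is allowed to run on (Linux only).
    return sorted(os.sched_getaffinity(0))


if __name__ == "__main__":
    before = allowed_cpus()
    # import torch  # uncomment to compare the mask before and after the import
    after = allowed_cpus()
    print("parent before import:", len(before), "CPUs")
    print("parent after import: ", len(after), "CPUs")
    with Pool(4) as pool:
        # Each worker reports its own affinity mask.
        for mask in pool.map(allowed_cpus, range(4)):
            print("worker:", len(mask), "CPUs", mask[:8])
```

If the masks shrink to a single CPU once torch is imported, the hypothesis fits the symptoms; if they stay full, the cause lies elsewhere.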

torch version: 2.0.1+cu118

I'm curious: what OS are you running on?

I'm running Oracle Linux 7; you can think of it as CentOS 7. The torch environment was installed through conda.

Can you try adding multiprocessing.set_start_method('spawn') before with multiprocessing.Pool(8) as p:?

It reports:

RuntimeError: context has already been set

You might need to do something like this:

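import multiprocessing
import multiprocessing.pool  # "import multiprocessing" alone does not guarantee the pool submodule is loaded

ctx = multiprocessing.get_context('spawn')
pool = multiprocessing.pool.Pool(8, context=ctx)  # same as ctx.Pool(8); leaves the global start method untouched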
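With spawn, each worker re-imports the main script instead of inheriting the parent's memory, so anything created under if __name__ == "__main__": is missing in the workers, and module-level setup (such as listing the data directory) is repeated in each of them. A sketch that instead builds the dataset once per worker through the Pool initializer, assuming the same directory-of-files layout; init_worker, run and _dset are hypothetical names:

```python
import multiprocessing
import os
import sys

import numpy as np


class CSDataset:
    def __init__(self, data_dir):
        self.files = [os.path.join(data_dir, name) for name in sorted(os.listdir(data_dir))]

    def __len__(self):
        return len(self.files)

    def __getitem__(self, idx):
        with open(self.files[idx], "rb") as f:
            return f.read()


_dset = None  # per-worker dataset, set by init_worker


def init_worker(data_dir):
    # Runs once in every worker process: build the dataset there
    # instead of relying on a global inherited through fork().
    global _dset
    _dset = CSDataset(data_dir)


def work(ids):
    # Read every file in this chunk and return how many were read.
    return len([_dset[i] for i in ids])


def run(data_dir, num_procs=8):
    n = len(CSDataset(data_dir))
    chunks = np.array_split(np.arange(n), num_procs)
    ctx = multiprocessing.get_context("spawn")
    with ctx.Pool(num_procs, initializer=init_worker, initargs=(data_dir,)) as pool:
        return sum(pool.map(work, chunks))


if __name__ == "__main__":
    if len(sys.argv) > 1:  # pass the data directory as the first argument
        print(run(sys.argv[1]))
```

Because nothing here depends on state inherited from the parent, the same functions also work with the forkserver context.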

Switching the start method to spawn or forkserver fixed the minimal example.
However, my actual problem is still not solved, even when I pass multiprocessing_context to DataLoader. CPU utilization goes up a bit, but training it/s barely improves. I suspect it is now hitting the limits of Python's IPC and queues. Maybe I have to implement a custom DataLoader in C++?
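Before moving to C++, it may help to measure how expensive serialization is for one sample. A plain multiprocessing.Queue pickles every item it sends, whereas torch.multiprocessing is documented to move tensors into shared memory so that only a handle is pickled; the timing from this sketch is therefore best read as an upper bound on per-sample IPC cost. It assumes a ~50 MB float32 array as a stand-in for one decompressed matrix (the shape is made up), and roundtrip_seconds is a hypothetical helper:

```python
import pickle
import time

import numpy as np


def roundtrip_seconds(arr, repeats=5):
    # Average time to pickle and unpickle `arr`, roughly what a plain
    # multiprocessing.Queue does per item (excluding the pipe transfer itself).
    start = time.perf_counter()
    for _ in range(repeats):
        out = pickle.loads(pickle.dumps(arr, protocol=pickle.HIGHEST_PROTOCOL))
    elapsed = (time.perf_counter() - start) / repeats
    return elapsed, out


if __name__ == "__main__":
    # 12,500 x 1,000 float32 values = 50,000,000 bytes.
    sample = np.random.rand(12_500, 1_000).astype(np.float32)
    secs, copy = roundtrip_seconds(sample)
    assert np.array_equal(copy, sample)
    mb = sample.nbytes / 1e6
    print(f"{mb:.0f} MB: {secs * 1e3:.1f} ms per pickle round trip ({mb / secs / 1e3:.2f} GB/s)")
```

If one round trip takes only a small fraction of a training step, Python IPC is unlikely to be the main bottleneck.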

Can you provide a code snippet of your DataLoader and Dataset setup?

It's like this (almost the same as the example, just reading the actual HDF5 data):

import os
import multiprocessing
import h5py
from torch.utils.data import Dataset, DataLoader

class CSDataset(Dataset):
    def __init__(self, dir):
        self.files = [f"{dir}/{file}" for file in os.listdir(dir)]

    def __len__(self):
        return len(self.files)

    def __getitem__(self, idx):
        with h5py.File(self.files[idx], "r") as f:  # open read-only
            sample = {"X": f["X"][:], "y": f["y"][:].reshape(-1, 1)}
            return sample

if __name__ == "__main__":  # required with the forkserver/spawn start methods
    ctx = multiprocessing.get_context("forkserver")
    dset = CSDataset(path_to_data)
    dataloader = DataLoader(dset, batch_size=None, shuffle=True, pin_memory=True,
                            num_workers=4, multiprocessing_context=ctx,
                            persistent_workers=True)
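If the CPU-affinity hypothesis turns out to be right, one possible workaround is to widen each DataLoader worker's CPU mask when the worker starts; DataLoader calls worker_init_fn with the worker id inside every worker process. A Linux-only sketch, where reset_affinity is a hypothetical helper and allowing every CPU is an assumption about the mask you want:

```python
import os


def reset_affinity(worker_id=None):
    # Hypothetical worker_init_fn: ask the kernel to let this process run on
    # every CPU of the machine. The kernel intersects the request with any
    # cgroup cpuset restriction, so the resulting mask can be a subset.
    os.sched_setaffinity(0, range(os.cpu_count()))
    return os.sched_getaffinity(0)


if __name__ == "__main__":
    print(f"mask now has {len(reset_affinity(0))} CPUs")
```

It would be passed as DataLoader(..., worker_init_fn=reset_affinity). This only changes the workers; the main training process keeps whatever mask it had.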