Fetching a new iterator takes a long time with DistributedDataParallel but not if running normally on a single GPU

I am trying to adapt some existing training code to run on multiple GPUs on the same node using DistributedDataParallel. I have a main_worker_function which sets up the dataset, model, etc., then runs training and validation. I either run a single instance of this directly on a single GPU, or launch it with mp.spawn to run with DDP:

if is_distributed:
    print("Using multiprocessing...")
    devids = ["cuda:{0}".format(x) for x in config.gpu_list]
    print("Using DistributedDataParallel on these devices: {}".format(devids))
    mp.spawn(main_worker_function, nprocs=ngpus, args=(ngpus, is_distributed, config))
else:
    print("Only one gpu found, not using multiprocessing...")
    main_worker_function(0, ngpus, is_distributed, config)
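
For context, main_worker_function does the usual per-process DDP setup before building the engine and data loaders. A rough sketch is below; the process-group init is the standard pattern rather than a verbatim copy of our code, and build_model is a placeholder for our model setup:

import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP

def main_worker_function(rank, ngpus, is_distributed, config):
    if is_distributed:
        # One process per GPU; rank is supplied by mp.spawn.
        # Assumes MASTER_ADDR/MASTER_PORT are set in the environment.
        dist.init_process_group(backend="nccl", init_method="env://",
                                world_size=ngpus, rank=rank)
    device = torch.device("cuda", rank)
    model = build_model(config).to(device)   # placeholder for our model construction
    if is_distributed:
        model = DDP(model, device_ids=[rank])
    # ... build datasets / data loaders, then run training and validation ...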

Currently I see no increase in running speed when running on two GPUs with DDP, and upon closer inspection it seems to be because fetching a new validation iterator takes an enormous amount of time. To check this I replaced my training function with:

from time import time

time0 = time()
print("Fetching iterator")
val_iter = iter(self.data_loaders["validation"])   # worker processes are started here
time1 = time()
print("Done in time: ", time1 - time0)

val_iter = iter(self.data_loaders["validation"])   # second iterator, workers started again
time2 = time()
print("Done in time: ", time2 - time1)

Running the above code on a single GPU with num_workers=4 yields:

Fetching iterator
Done in time:  0.40811872482299805
Done in time:  0.7533595561981201

Running the same code on a single GPU with num_workers=4, but spawned as a single subprocess with mp.spawn, yields:

Fetching iterator
Done in time:  49.32914686203003
Done in time:  50.69797325134277

As far as I can tell it has something to do with how the worker processes are spawned, because spawning a subprocess with num_workers=2 yields:

Fetching iterator
Done in time:  25.56034803390503
Done in time:  28.712769746780396

But I’m not sure what behaviour changes between single-process and multiprocessing mode to cause this increase in time, and I’m not quite sure how to check (I’ve been consulting https://pytorch.org/docs/stable/data.html#single-and-multi-process-data-loading, but I’m not sure it says anything about how worker spawning differs between the two cases). Note that the dataset and sampler are currently initialized the same way in both modes (I was previously using a DistributedSampler in DDP mode, but this problem persists regardless of whether or not that sampler is used).

@Whisky-Jack Can you share the code you’re using to initialize the DataLoader on each process?

cc @VitalyFedyunin

Note that I’ve made some important progress on this since last week, so I’ve put that first, with the DataLoader information below (you can jump to that part if it’s more relevant).

I have done more investigation on this since last week, and I have a better handle on where the problem is occurring. It seems to involve differences in how the worker processes are spawned under the different multiprocessing start methods. This is described a little in the platform-specific behaviours section of torch.utils.data — PyTorch 2.1 documentation, although it doesn’t really discuss differences between running in a single process vs. processes launched with spawn.

Calling print(mp.get_context()) in the main_worker_function in single GPU mode yields:

<multiprocessing.context.ForkContext object at 0x7fc14dd64da0>

While in the spawned process

<multiprocessing.context.SpawnContext object at 0x7f8e02fd0ef0>

Consequently, I think the worker processes are being spawned using fork() in the single-process case and spawn() in the multiprocessing case.

The documentation states:

  • On Unix, fork() is the default multiprocessing start method. Using fork(), child workers typically can access the dataset and Python argument functions directly through the cloned address space.
  • On Windows, spawn() is the default multiprocessing start method. Using spawn(), another interpreter is launched which runs your main script, followed by the internal worker function that receives the dataset, collate_fn and other arguments through pickle serialization.

Since I’m running on Linux, I’m guessing that with fork the cloned address space works with low overhead, while with spawn something about pickling the dataset and passing it to the workers is causing a large overhead.
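
One way to poke at this (a sketch, not something I’ve fully tested): DataLoader accepts a multiprocessing_context argument, so the worker start method can be forced for a particular loader independently of how the surrounding process was launched, with the usual caveat that forking after CUDA has been initialized can cause problems. Here dataset, sampler and batch_size are the objects described below:

import torch.multiprocessing as mp
from torch.utils.data import DataLoader

# Force this loader's workers to be forked even inside a spawn-started process.
loader = DataLoader(dataset,
                    sampler=sampler,
                    batch_size=batch_size,
                    num_workers=4,
                    multiprocessing_context=mp.get_context("fork"))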

To check that this difference is causing the issue, I tried running in single GPU mode (so NOT spawning any subprocesses using mp.spawn, or wrapping with DistributedDataParallel), but changing the default start method using mp.set_start_method('spawn') in my main function. This results in the same context:

<multiprocessing.context.SpawnContext object at 0x7f5e3b4a29e8>

And produces the same lag as in the multiprocessing case:

Fetching iterator
Done in time:  52.480613708496094
Fetching iterator
Done in time:  52.47240161895752
Fetching iterator
Done in time:  53.32112169265747
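
For reference, that check was essentially just this, with no mp.spawn and no DDP involved (a sketch; config here stands in for our Hydra config):

import torch.multiprocessing as mp

if __name__ == "__main__":
    # Only the default start method for the DataLoader workers changes;
    # everything else is the normal single-GPU code path.
    mp.set_start_method("spawn")
    print(mp.get_context())   # <multiprocessing.context.SpawnContext ...>
    main_worker_function(0, 1, False, config)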

I then took a look at my dataset to see what might be causing problems. Our dataset is stored in h5 form, and is loaded into numpy arrays when the dataset is initialized.

The dataset consists of a number of types of data, some of which are too large to load into memory, and so are stored in memmaps (these memmaps are not initialized until the first call to get_item):

with h5py.File(self.h5_path, 'r') as h5_file:
    self.dataset_length = h5_file["labels"].shape[0]

    hdf5_hit_pmt    = h5_file["hit_pmt"]
    hdf5_hit_time   = h5_file["hit_time"]
    hdf5_hit_charge = h5_file["hit_charge"]

    # initialize memmap param dict
    self.pmt_dict    = {'shape':hdf5_hit_pmt.shape,    'offset':hdf5_hit_pmt.id.get_offset(),   'dtype':hdf5_hit_pmt.dtype}
    self.time_dict   = {'shape':hdf5_hit_time.shape,   'offset':hdf5_hit_time.id.get_offset(),  'dtype':hdf5_hit_time.dtype}
    self.charge_dict = {'shape':hdf5_hit_charge.shape, 'offset':hdf5_hit_charge.id.get_offset(),'dtype':hdf5_hit_charge.dtype}
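
For completeness, the memmaps themselves are then created lazily on the first call to get_item, roughly like this (a sketch of the pattern; the exact method and attribute names in our code differ slightly, and it assumes the HDF5 datasets are contiguous and uncompressed so a raw offset/shape/dtype view into the file is valid):

import numpy as np

def initialize_memmaps(self):
    # Called from get_item on the first access in each worker process.
    self.hit_pmt = np.memmap(self.h5_path, mode='r',
                             shape=self.pmt_dict['shape'],
                             offset=self.pmt_dict['offset'],
                             dtype=self.pmt_dict['dtype'])
    self.hit_time = np.memmap(self.h5_path, mode='r',
                              shape=self.time_dict['shape'],
                              offset=self.time_dict['offset'],
                              dtype=self.time_dict['dtype'])
    self.hit_charge = np.memmap(self.h5_path, mode='r',
                                shape=self.charge_dict['shape'],
                                offset=self.charge_dict['offset'],
                                dtype=self.charge_dict['dtype'])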

The other data is small enough to be loaded into memory, and is stored as attributes of the dataset:

self.labels    = np.array(h5_file["labels"])
self.energies  = np.array(h5_file["energies"])
self.positions = np.array(h5_file["positions"])
self.angles    = np.array(h5_file["angles"])

This latter data seems to be what is causing the issue. If none of this data is stored as part of the dataset, very little lag is observed:

Fetching iterator
Done in time:  1.6751770973205566
Fetching iterator
Done in time:  2.464519500732422
Fetching iterator
Done in time:  2.2533977031707764

Whereas if only the labels are stored, the lag increases:

Fetching iterator
Done in time:  9.610902309417725
Fetching iterator
Done in time:  10.206432342529297
Fetching iterator
Done in time:  9.656260013580322

Storing more of these arrays produces more lag.
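
To get a feel for whether this lag tracks the cost of serializing the dataset (which is what spawn-started workers receive), a quick check along these lines can be run; this is just how I’d measure it, not something from the original code:

import pickle
from time import time

# Rough measure of how expensive the dataset object is to serialize and
# how large the resulting payload is (dataset is the object built above).
t0 = time()
payload = pickle.dumps(dataset)
print("pickle time: {:.2f}s, size: {:.1f} MB".format(time() - t0, len(payload) / 1e6))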

I then tried moving this initialization step into the same initialization function in get_item that creates the memmaps (so this data is only loaded inside the worker processes after they have spawned), and this seems to eliminate most of the lag:

Fetching iterator
Done in time:  1.7485809326171875
Fetching iterator
Done in time:  8.036261796951294
Fetching iterator
Done in time:  3.001110076904297
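
The change itself is just a guarded initializer at the top of get_item, roughly as follows (a sketch; the method and attribute names are illustrative):

def __getitem__(self, idx):
    # Load the in-memory arrays (and create the memmaps) on the first call in
    # each worker process, so nothing large has to be pickled by spawn().
    if self.labels is None:
        self.initialize_data()   # hypothetical helper: opens the h5 file and fills labels/energies/positions/angles
    # ... assemble and return the item for this index as before ...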

It does add overhead to the first call to get_item, but not the 50s lag observed previously. To see this I ran:

print("Fetching iterator...")
val_iter = iter(self.data_loaders["validation"])
time1 = time()
print("Done in time: ", time1 - time0)

print("Fetching data...")
data = next(val_iter)
time0 = time()
print("Done in time: ", time0 - time1)

print("Fetching data...")
data = next(val_iter)
time1 = time()
print("Done in time: ", time1 - time0)

Yielding:

Fetching iterator...
Done in time:  1.7992744445800781
Fetching data...
Done in time:  6.465835094451904
Fetching data...
Done in time:  0.10752415657043457

This is not great, but probably tolerable. It would still be nice to know what goes wrong when this data is stored in __init__ and passed to the workers, if anyone has an idea why this lag is produced.
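
In case it’s useful to anyone hitting the same thing, one mitigation I still want to try (not something from the code above) is DataLoader’s persistent_workers=True, which keeps the worker processes alive across iterators, so the spawn-and-pickle cost should only be paid for the first epoch. A sketch, reusing the names from the loader code in the next section:

from torch.utils.data import DataLoader

# Keep workers alive between iterators instead of re-spawning them each epoch.
loader = DataLoader(dataset,
                    sampler=sampler,
                    batch_size=batch_size,
                    num_workers=4,
                    persistent_workers=True)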

DataLoader

Regarding the instantiation of the DataLoader, the process is a little convoluted because we’re using Hydra configs and the loader is built through a series of calls, but I have tried to pare it down into a reasonable example.

We have an engine class which handles training and so has the dataloaders as attributes. In our main_worker_function, we set this up using the config information using:

# Configure data loaders
for task, task_config in config.tasks.items():
    if 'data_loaders' in task_config:
        engine.configure_data_loaders(config.data, task_config.data_loaders, is_distributed, config.seed)

def configure_data_loaders(self, data_config, loaders_config, is_distributed, seed):
    """
    Set up data loaders from loaders config
    """
    for name, loader_config in loaders_config.items():
        self.data_loaders[name] = get_data_loader(**data_config, **loader_config, is_distributed=is_distributed, seed=seed)

We use the standard DataLoader class, passing a SubsetRandomSampler on a set of indices and our own dataset.

import numpy as np
from hydra.utils import instantiate
from torch.utils.data import DataLoader, SubsetRandomSampler

def get_data_loader(dataset, batch_size, sampler, num_workers, is_distributed, seed, split_path=None, split_key=None, transforms=None):
    # Load the indices for this split and restrict the sampler to them
    split_indices = np.load(split_path, allow_pickle=True)[split_key]

    sampler = SubsetRandomSampler(split_indices)
    dataset = instantiate(dataset, transforms=transforms, is_distributed=is_distributed)
    return DataLoader(dataset, sampler=sampler, batch_size=batch_size, num_workers=num_workers)

Hydra’s instantiate simply returns the object specified by the dataset config. I’ve checked, and at this point in the code dataset is just our dataset object.

For details on the dataset, refer to the code in the section above this one. I can clarify details further if this is too complicated.