I’ve made some important progress on this since last week, so I’ve put that first, with the dataloader information below (you can jump to that part if that’s more important).
I have investigated further and have a better handle on where the problem is occurring. It seems to involve differences in how the worker functions are started under different multiprocessing start methods. This is touched on in the platform-specific behaviours section of the torch.utils.data page in the PyTorch 2.1 documentation, although that section doesn’t really discuss the difference between running in a single process and running in processes launched with spawn.
Calling print(mp.get_context()) in the main_worker_function in single GPU mode yields:
<multiprocessing.context.ForkContext object at 0x7fc14dd64da0>
While in the spawned process it yields:
<multiprocessing.context.SpawnContext object at 0x7f8e02fd0ef0>
Consequently, I think the worker processes are being started using fork() in the single-process case and spawn() in the multiprocessing case.
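The check above can be reproduced in isolation with a small helper. This is only a sketch: it uses the standard-library multiprocessing module rather than torch.multiprocessing, and context_name is a hypothetical helper name, not something from our code.

```python
import multiprocessing as mp


def context_name(method=None):
    """Return the class name of the multiprocessing context for `method`.

    With method=None this reports the current default context of the
    process it is called in.
    """
    return type(mp.get_context(method)).__name__


if __name__ == "__main__":
    print("default:", context_name())
    print("spawn:  ", context_name("spawn"))
```

Calling it in both the single GPU run and inside a spawned process should show the same two context types as the printouts above.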
The documentation states:
- On Unix, fork() is the default multiprocessing start method. Using fork(), child workers typically can access the dataset and Python argument functions directly through the cloned address space.
- On Windows, spawn() is the default multiprocessing start method. Using spawn(), another interpreter is launched which runs your main script, followed by the internal worker function that receives the dataset, collate_fn and other arguments through pickle serialization.
Since I’m running on Linux, I’m guessing that with fork the cloned address space works correctly with low overhead, while with spawn something about pickling the dataset causes a large overhead when it is passed to the workers.
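One way to test this guess is to measure how large the pickled payload is and how long serialization takes as the in-memory data grows. The sketch below is not our dataset: EagerDataset and pickled_size_and_time are hypothetical names, and a bytes object stands in for an in-memory array such as the labels.

```python
import pickle
import time


class EagerDataset:
    """Hypothetical stand-in for a dataset that loads arrays in __init__."""

    def __init__(self, n_events):
        # bytes object standing in for an in-memory array such as self.labels
        self.labels = bytes(n_events)


def pickled_size_and_time(obj):
    """Return (payload size in bytes, seconds spent in pickle.dumps)."""
    start = time.perf_counter()
    payload = pickle.dumps(obj)
    return len(payload), time.perf_counter() - start


if __name__ == "__main__":
    for n in (1_000, 1_000_000, 10_000_000):
        size, seconds = pickled_size_and_time(EagerDataset(n))
        print(f"{n:>10} events -> {size:>10} bytes pickled in {seconds:.4f}s")
```

If the real dataset shows a pickling time far below the observed lag, the overhead must come from somewhere else in the worker start-up.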
To check that this difference is causing the issue, I tried running in single GPU mode (so NOT spawning any subprocesses using mp.spawn, or wrapping with DistributedDataParallel), but changing the default start method using mp.set_start_method('spawn') in my main function. This results in the same context:
<multiprocessing.context.SpawnContext object at 0x7f5e3b4a29e8>
And produces the same lag as in the multiprocessing case:
Fetching iterator
Done in time: 52.480613708496094
Fetching iterator
Done in time: 52.47240161895752
Fetching iterator
Done in time: 53.32112169265747
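The single-process check described above can be sketched as follows. It again assumes the standard-library multiprocessing module, and use_start_method is a hypothetical helper; force=True is used here because the start method may already have been fixed by an earlier mp.get_context() call.

```python
import multiprocessing as mp


def use_start_method(method):
    """Set the global start method and return the resulting context name.

    force=True overrides a start method that was already fixed, for example
    by an earlier mp.get_context() call.
    """
    mp.set_start_method(method, force=True)
    return type(mp.get_context()).__name__


if __name__ == "__main__":
    # Must run before any DataLoader workers are created.
    print(use_start_method("spawn"))
```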
I then took a look at my dataset to see what might be causing problems. Our dataset is stored in h5 form, and is loaded into numpy arrays when the dataset is initialized.
The dataset consists of several types of data, some of which are too large to load into memory and so are accessed through memmaps (these memmaps are not initialized until the first call to get_item):
with h5py.File(self.h5_path, 'r') as h5_file:
    self.dataset_length = h5_file["labels"].shape[0]
    hdf5_hit_pmt = h5_file["hit_pmt"]
    hdf5_hit_time = h5_file["hit_time"]
    hdf5_hit_charge = h5_file["hit_charge"]
    # initialize memmap param dict
    self.pmt_dict = {'shape': hdf5_hit_pmt.shape, 'offset': hdf5_hit_pmt.id.get_offset(), 'dtype': hdf5_hit_pmt.dtype}
    self.time_dict = {'shape': hdf5_hit_time.shape, 'offset': hdf5_hit_time.id.get_offset(), 'dtype': hdf5_hit_time.dtype}
    self.charge_dict = {'shape': hdf5_hit_charge.shape, 'offset': hdf5_hit_charge.id.get_offset(), 'dtype': hdf5_hit_charge.dtype}
The other data is small enough to be loaded into memory, and is stored as attributes of the dataset:
    self.labels = np.array(h5_file["labels"])
    self.energies = np.array(h5_file["energies"])
    self.positions = np.array(h5_file["positions"])
    self.angles = np.array(h5_file["angles"])
This latter data seems to be what is causing the issue. If none of this data is stored as part of the dataset, very little lag is observed:
Fetching iterator
Done in time: 1.6751770973205566
Fetching iterator
Done in time: 2.464519500732422
Fetching iterator
Done in time: 2.2533977031707764
Whereas if only the labels array is stored, the lag increases:
Fetching iterator
Done in time: 9.610902309417725
Fetching iterator
Done in time: 10.206432342529297
Fetching iterator
Done in time: 9.656260013580322
Storing more of these arrays produces more lag.
I then tried moving this initialization step into the same initialization function in get_item that creates the memmaps (so this data is initialized inside the worker processes after they have started), and this seems to eliminate most of the lag:
Fetching iterator
Done in time: 1.7485809326171875
Fetching iterator
Done in time: 8.036261796951294
Fetching iterator
Done in time: 3.001110076904297
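The deferred-loading pattern described above looks roughly like the sketch below. It is a plain-Python stand-in rather than our actual dataset class: LazyDataset and load_small_arrays are hypothetical names, and a bytes object takes the place of the arrays read from the h5 file.

```python
import pickle


def load_small_arrays(n_events):
    """Hypothetical loader standing in for reading labels etc. from the h5 file."""
    return {"labels": bytes(n_events)}


class LazyDataset:
    """Plain-Python stand-in for a torch Dataset that defers loading."""

    def __init__(self, n_events):
        self.n_events = n_events
        self.arrays = None  # nothing heavy is stored before workers start

    def _ensure_loaded(self):
        # Runs once per process, on the first item request in that process.
        if self.arrays is None:
            self.arrays = load_small_arrays(self.n_events)

    def __len__(self):
        return self.n_events

    def __getitem__(self, index):
        self._ensure_loaded()
        return self.arrays["labels"][index]


if __name__ == "__main__":
    ds = LazyDataset(1_000_000)
    print("pickled size before first access:", len(pickle.dumps(ds)))
    ds[0]
    print("pickled size after first access: ", len(pickle.dumps(ds)))
```

Because the object handed to the workers holds no arrays yet, whatever is serialized at worker start-up stays small, and each worker pays the loading cost on its first item instead.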
This change does add overhead to the first call to get_item, but not the 50s lag observed previously. To see this I ran:
from time import time

# start the clock before creating the iterator
time0 = time()
print("Fetching iterator...")
val_iter = iter(self.data_loaders["validation"])
time1 = time()
print("Done in time: ", time1 - time0)
print("Fetching data...")
data = next(val_iter)
time0 = time()
print("Done in time: ", time0 - time1)
print("Fetching data...")
data = next(val_iter)
time1 = time()
print("Done in time: ", time1 - time0)
Yielding:
Fetching iterator...
Done in time: 1.7992744445800781
Fetching data...
Done in time: 6.465835094451904
Fetching data...
Done in time: 0.10752415657043457
Which is not great, but probably tolerable. It would be nice to know what goes wrong when this data is stored in __init__ and passed to the workers, if anyone has an idea why that produces this lag.
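For anyone wanting to repeat the timing measurements, the pattern used above can be wrapped in a small helper. This is just a convenience sketch: timed is a hypothetical name, and a plain list stands in for self.data_loaders["validation"].

```python
from time import perf_counter


def timed(label, fn):
    """Call fn(), print how long it took, and return (result, seconds)."""
    start = perf_counter()
    result = fn()
    elapsed = perf_counter() - start
    print(f"{label} done in {elapsed:.3f}s")
    return result, elapsed


if __name__ == "__main__":
    # A list stands in for self.data_loaders["validation"].
    loader = [[1, 2], [3, 4]]
    val_iter, _ = timed("Fetching iterator...", lambda: iter(loader))
    first, _ = timed("Fetching data...", lambda: next(val_iter))
    second, _ = timed("Fetching data...", lambda: next(val_iter))
```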
DataLoader
Regarding the instantiation of the dataloader, the process is a little convoluted because we’re running with hydra configs and the dataloader is created through a series of calls, but I have tried to pare it down into a reasonable example.
We have an engine class which handles training and so has the dataloaders as attributes. In our main_worker_function, we set this up from the config information using:
# Configure data loaders
for task, task_config in config.tasks.items():
    if 'data_loaders' in task_config:
        engine.configure_data_loaders(config.data, task_config.data_loaders, is_distributed, config.seed)

def configure_data_loaders(self, data_config, loaders_config, is_distributed, seed):
    """
    Set up data loaders from loaders config
    """
    for name, loader_config in loaders_config.items():
        self.data_loaders[name] = get_data_loader(**data_config, **loader_config, is_distributed=is_distributed, seed=seed)
We use the standard DataLoader class, passing a SubsetRandomSampler on a set of indices and our own dataset.
import numpy as np
from hydra.utils import instantiate
from torch.utils.data import DataLoader
from torch.utils.data import SubsetRandomSampler

def get_data_loader(dataset, batch_size, sampler, num_workers, is_distributed, seed, split_path=None, split_key=None, transforms=None):
    # indices of the events belonging to this split
    split_indices = np.load(split_path, allow_pickle=True)[split_key]
    # the sampler argument is replaced by a SubsetRandomSampler over the split
    sampler = SubsetRandomSampler(split_indices)
    # build the dataset object described by the hydra config
    dataset = instantiate(dataset, transforms=transforms, is_distributed=is_distributed)
    return DataLoader(dataset, sampler=sampler, batch_size=batch_size, num_workers=num_workers)
Hydra instantiate simply returns the object specified in the dataset (which comes from the config). I’ve checked and the dataset at this point in the code is just our dataset object.
For details on the dataset, refer to the code in the section above this one. I can clarify details further if this is too complicated.