Torchdata w/ DDP: MemoryError at the start of epoch 2

Hi All!

I’ve been moving my team over to torchdata and really like it as an interface. It is a nice way to write pipelines and load data in from tarfiles (coming from webdataset).

For DDP, I’ve been following the advice in the DataLoader2 tutorial (https://pytorch.org/data/beta/dlv2_tutorial.html#multiprocessing-distributed) and combining the reading services:

rs = SequentialReadingService(dist_rs, mp_rs)
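
For reference, my setup looks roughly like this (a sketch; build_datapipe is a placeholder for my own pipeline construction, described in the update at the bottom):

from torchdata.dataloader2 import (
    DataLoader2,
    DistributedReadingService,
    MultiProcessingReadingService,
    SequentialReadingService,
)

datapipe = build_datapipe()  # placeholder for my FileLister -> ... pipeline (see update below)

mp_rs = MultiProcessingReadingService(num_workers=4)
dist_rs = DistributedReadingService()
rs = SequentialReadingService(dist_rs, mp_rs)  # distributed sharding first, then worker processes

dl = DataLoader2(datapipe, reading_service=rs)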

It uses a fair amount of memory but can easily support 4 workers across 8 GPUs (stabilizing around 150 GB of RAM). However, after completing the first epoch and starting the second, the process memory spikes and I get a MemoryError. The machine has 500 GB of RAM in total, so the epoch restart uses more than double the memory.
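
Each rank then runs a plain epoch loop over the DataLoader2 from the sketch above (train_step and num_epochs are placeholders); the first pass is fine, and the spike happens when the second pass starts:

for epoch in range(num_epochs):
    for batch in dl:        # first epoch stabilizes around 150 GB RSS
        train_step(batch)   # placeholder for forward/backward/optimizer step
    # memory spikes here, when iteration for the next epoch begins
dl.shutdown()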

My suspicion is that there is some issue in the way that states are reset and synced across the dataloaders. This is supported by the following line from the larger logs below:

source_datapipe = request.reset_fn(source_datapipe)

Does anyone have any suggestions on how to approach this? Thank you!

Since each of the 32 processes prints its own traceback, the following is a set(list()) of all the error lines that were printed, so the ordering is jumbled.

  File "/home/ec2-user/anaconda3/envs/torch_conda/lib/python3.10/site-packages/torch/storage.py", line 747, in __reduce__
    self.run()
  File "/home/ec2-user/anaconda3/envs/torch_conda/lib/python3.10/site-packages/torch/utils/data/graph.py", line 98, in traverse_dps
  File "/storage/home/ec2-user/AedDataScience/src/aed_data_science/PytorchTraining/TrainingScripts/TrainModel_DDP.py", line 118, in <module>
    res = self._recv_bytes()
    buf = self._recv(4)
    response = self.request_queue.get(block=block)
    d[dp_id][1].update(_traverse_helper(item, only_datapipe, cache.copy()))
    while not context.join():
  File "/home/ec2-user/anaconda3/envs/torch_conda/lib/python3.10/site-packages/torch/multiprocessing/spawn.py", line 197, in start_processes
  File "/home/ec2-user/anaconda3/envs/torch_conda/lib/python3.10/site-packages/torchdata/dataloader2/communication/protocol.py", line 94, in get_new_request
    request = protocol.get_new_request(block=blocking_request_get)
  File "/home/ec2-user/anaconda3/envs/torch_conda/lib/python3.10/site-packages/torchdata/dataloader2/communication/eventloop.py", line 118, in DataPipeToQueuesLoop
  File "/home/ec2-user/anaconda3/envs/torch_conda/lib/python3.10/multiprocessing/connection.py", line 414, in _recv_bytes
  File "/home/ec2-user/anaconda3/envs/torch_conda/lib/python3.10/site-packages/torch/serialization.py", line 589, in _legacy_save
  File "/home/ec2-user/anaconda3/envs/torch_conda/lib/python3.10/site-packages/torch/multiprocessing/spawn.py", line 109, in join
    mp.spawn(MultiGPUexecutor, args=(config.gpu_count, config, port), nprocs=config.gpu_count)
  File "/home/ec2-user/anaconda3/envs/torch_conda/lib/python3.10/site-packages/torchdata/dataloader2/utils/worker.py", line 159, in process_reset_fn
  File "/home/ec2-user/anaconda3/envs/torch_conda/lib/python3.10/site-packages/torch/_tensor.py", line 208, in __reduce_ex__
  File "/home/ec2-user/anaconda3/envs/torch_conda/lib/python3.10/multiprocessing/process.py", line 108, in run
    _legacy_save(obj, opened_file, pickle_module, pickle_protocol)
    graph = traverse_dps(datapipe)
  File "/home/ec2-user/anaconda3/envs/torch_conda/lib/python3.10/multiprocessing/connection.py", line 931, in wait
Traceback (most recent call last):
    storage._write_file(f, _should_read_directly(f), True, torch._utils._element_size(dtype))
  [Previous line repeated 1 more time]
    fd_event_list = self._selector.poll(timeout)
    torch.save(self, b, _use_new_zipfile_serialization=False)
    ready = selector.select(timeout)
  File "/home/ec2-user/anaconda3/envs/torch_conda/lib/python3.10/site-packages/torch/serialization.py", line 445, in save
  File "/home/ec2-user/anaconda3/envs/torch_conda/lib/python3.10/site-packages/torch/utils/data/graph.py", line 67, in _list_connected_datapipes
  File "/home/ec2-user/anaconda3/envs/torch_conda/lib/python3.10/multiprocessing/connection.py", line 216, in recv_bytes
    return start_processes(fn, args, nprocs, join, daemon, start_method='spawn')
  File "/home/ec2-user/anaconda3/envs/torch_conda/lib/python3.10/site-packages/torchdata/dataloader2/communication/iter.py", line 147, in DataPipeBehindQueues
    p.dump(scan_obj)
  File "/home/ec2-user/anaconda3/envs/torch_conda/lib/python3.10/selectors.py", line 416, in select
  File "/home/ec2-user/anaconda3/envs/torch_conda/lib/python3.10/multiprocessing/connection.py", line 379, in _recv
    def __reduce_ex__(self, proto):
  File "/home/ec2-user/anaconda3/envs/torch_conda/lib/python3.10/site-packages/torchdata/dataloader2/communication/iter.py", line 153, in DataPipeBehindQueues
    buf = self._recv_bytes(maxlength)
    chunk = read(handle, remaining)
    source_datapipe = request.reset_fn(source_datapipe)
  File "/home/ec2-user/anaconda3/envs/torch_conda/lib/python3.10/site-packages/torch/multiprocessing/spawn.py", line 239, in spawn
    items = _list_connected_datapipes(datapipe, only_datapipe, cache.copy())
  File "/home/ec2-user/anaconda3/envs/torch_conda/lib/python3.10/multiprocessing/queues.py", line 103, in get
    self._target(*self._args, **self._kwargs)
    return _traverse_helper(datapipe, only_datapipe=True, cache=cache)
  File "/home/ec2-user/anaconda3/envs/torch_conda/lib/python3.10/site-packages/torch/utils/data/graph.py", line 145, in _traverse_helper
    for _ in loop:
  File "/home/ec2-user/anaconda3/envs/torch_conda/lib/python3.10/multiprocessing/process.py", line 314, in _bootstrap
    ready = multiprocessing.connection.wait(
  File "/home/ec2-user/anaconda3/envs/torch_conda/lib/python3.10/site-packages/torch/utils/data/graph.py", line 140, in _traverse_helper

Updates:

I’m on torch 2.0 (using torch.compile) and torchdata 0.6.0.

My datapipe is constructed as follows:

  1. FileLister
  2. shuffle
  3. sharding_filter
  4. metadata and fixed-length audio extraction
  5. shuffle
  6. augment samples
  7. shuffle
Is this construction risky? The repeated shuffles are there because I’m trying to avoid yielding (a) 100 audio clips from a single long file back to back, and (b) 10 augmentations of the same clip sequentially. A rough sketch of the pipeline is below.
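
For concreteness (extract_clips, augment, and the buffer sizes are placeholders for my actual functions and settings):

from torchdata.datapipes.iter import FileLister

def build_datapipe(root):
    dp = FileLister(root=root, masks="*.tar")
    dp = dp.shuffle()                  # shuffle shard order before sharding
    dp = dp.sharding_filter()          # split shards across ranks and workers
    dp = dp.flatmap(extract_clips)     # metadata + fixed-length audio extraction (many clips per file)
    dp = dp.shuffle(buffer_size=1000)  # mix clips from different files
    dp = dp.flatmap(augment)           # several augmented samples per clip
    dp = dp.shuffle(buffer_size=1000)  # mix augmentations
    return dp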