How to share data among DataLoader processes to save memory

PyTorch’s data loader uses multiprocessing in Python and each process gets a replica of the dataset. When the dataset is huge, this data replication leads to memory issues.

Normally, separate processes (unlike threads) need shared memory to share data. Is there an easy way to share common data across all the data loading worker processes in PyTorch? Maybe someone has already implemented this (I could not find anything yet).

Thanks.

If you are lazily loading the data (which is the common use case, if you are dealing with large datasets), the memory overhead from the copies might be small in comparison to the overall memory usage in the script.
That being said, you could try to use shared arrays as described here instead.
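
As a rough illustration of the shared-array idea (not the code from the linked post), the sketch below uses Python's standard `multiprocessing.shared_memory` module (Python 3.8+). The helper names `create_shared_floats` and `read_shared_float` are hypothetical. The assumption is that the data fits a flat float64 layout, so a worker can attach to the block by name instead of holding its own copy.

```python
from array import array
from multiprocessing import shared_memory

ITEM_SIZE = 8  # bytes per float64 element


def create_shared_floats(values):
    """Copy `values` into a new shared-memory block and return the block."""
    n = len(values)
    shm = shared_memory.SharedMemory(create=True, size=max(n, 1) * ITEM_SIZE)
    view = shm.buf.cast("d")       # reinterpret the raw bytes as float64
    view[:n] = array("d", values)  # one-time copy into shared memory
    view.release()                 # drop the view so close() can succeed later
    return shm


def read_shared_float(name, index):
    """Attach to an existing block by name and read one element (no bulk copy)."""
    shm = shared_memory.SharedMemory(name=name)
    view = shm.buf.cast("d")
    try:
        return view[index]
    finally:
        view.release()
        shm.close()  # detaches this handle only; the block itself stays alive
```

In a Dataset, the attach step would typically happen lazily inside each worker, keyed by the block name. The creating process calls `close()` and `unlink()` once, when the data is no longer needed.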

See here, @ptrblck

I am facing a similar problem to the one mentioned here, but in my case I want to share a class object (a very large tree structure) between workers. As far as I can see, Python multiprocessing only supports sharing arrays. Is there a way to share other kinds of objects between workers?

Would @Pietro_Cicalese’s approach of using queues work (see his detailed explanation in the linked post)?
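
Another option, sketched here under assumptions rather than taken from the linked post, is to flatten the tree into plain integer arrays, since arrays are what the sharing tools handle well. The `Node` class and the helpers `flatten` and `contains` are hypothetical, and the example assumes a binary search tree with integer values. A real tree may need more arrays (for example one per field).

```python
from array import array
from dataclasses import dataclass
from typing import Optional


@dataclass
class Node:
    value: int
    left: Optional["Node"] = None
    right: Optional["Node"] = None


def flatten(root):
    """Turn a pointer-based tree into three index arrays (-1 means 'no child')."""
    values, left, right = array("q"), array("q"), array("q")
    stack = [(root, -1, False)] if root is not None else []
    while stack:
        node, parent, is_right = stack.pop()
        idx = len(values)
        values.append(node.value)
        left.append(-1)
        right.append(-1)
        if parent >= 0:
            (right if is_right else left)[parent] = idx
        if node.right is not None:
            stack.append((node.right, idx, True))
        if node.left is not None:
            stack.append((node.left, idx, False))
    return values, left, right


def contains(values, left, right, target):
    """Binary-search-tree lookup that walks indices instead of object pointers."""
    idx = 0 if len(values) else -1
    while idx != -1:
        if target == values[idx]:
            return True
        idx = left[idx] if target < values[idx] else right[idx]
    return False
```

Once the tree is in this form, the three arrays can be copied into shared memory once and read by every worker without pickling the object graph.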

Hi @ptrblck, I know this thread might be dated, but I wanted to second @Pietro_Cicalese's observation on the proposed approach (the very last paragraph).

I also observed significant overhead when using the built-in Queue approach for multi-process data loading, predominantly because ConnectionWrapper unpickles the received byte array here. I see that Connection requires recv to return something picklable, but a byte array is also picklable. Or is it just an intermediary containing the fd handle/size? Also, multiple connections are established between the processes, each of which has to pass answer_challenge.

That, and the fact that transferring batches made of multiple smaller tensors seems to exacerbate the issue, is why I wanted to ask (in case you know):

  • What’s the recommended way to share larger sample batches made of multiple tensors using the built-in tools? Or is the only option to build a custom shared-memory-based file sharing solution in a single-producer, single-consumer style?

Here’s an excerpt from a profiling session for 50 steps (the same number of batches in this case), each batch containing 512 samples:

         2705790 function calls (2412183 primitive calls) in 890.012 seconds

   Ordered by: cumulative time

   ncalls  tottime  percall  cumtime  percall filename:lineno(function)
294176/571    0.227    0.000  889.999    1.559 {built-in method builtins.next}
      571    0.003    0.000  889.997    1.559 combined_loader.py:283(__next__)
      571    0.006    0.000  889.992    1.559 combined_loader.py:110(__next__)
      570    0.010    0.000  889.179    1.560 dataloader.py:625(__next__)
      570    0.009    0.000  889.092    1.560 dataloader.py:1298(_next_data)
      570    0.003    0.000  888.554    1.559 dataloader.py:1265(_get_data)
      585    0.004    0.000  888.551    1.519 dataloader.py:1119(_try_get_data)
      585    0.013    0.000  888.547    1.519 queues.py:98(get)
      570    0.189    0.000  568.985    0.998 {built-in method _pickle.loads}
    11970    0.203    0.000  567.946    0.047 reductions.py:354(rebuild_storage_fd)
    11970    0.101    0.000  566.926    0.047 resource_sharer.py:55(detach)
    11970    0.209    0.000  412.757    0.034 resource_sharer.py:81(get_connection)
    11970    0.115    0.000  411.536    0.034 connection.py:493(Client)
    36480    0.215    0.000  408.430    0.011 connection.py:208(recv_bytes)
    36480    0.250    0.000  408.111    0.011 connection.py:413(_recv_bytes)
    72960    0.444    0.000  407.758    0.006 connection.py:374(_recv)
    72960  407.113    0.006  407.113    0.006 {built-in method posix.read}
      586    0.012    0.000  320.303    0.547 connection.py:917(wait)
      586    0.011    0.000  320.256    0.547 selectors.py:403(select)
      586  320.241    0.546  320.241    0.546 {method 'poll' of 'select.poll' objects}
      585    0.004    0.000  319.511    0.546 connection.py:253(poll)
      585    0.004    0.000  319.506    0.546 connection.py:423(_poll)
    11970    0.186    0.000  314.161    0.026 connection.py:747(answer_challenge)
    11970    0.126    0.000  153.878    0.013 reduction.py:186(recv_handle)
    11970    0.196    0.000  153.427    0.013 reduction.py:153(recvfds)
    11970  153.162    0.013  153.162    0.013 {method 'recvmsg' of '_socket.socket' objects}
    11970    0.156    0.000   96.503    0.008 connection.py:732(deliver_challenge)
    35910    0.299    0.000    1.196    0.000 connection.py:181(send_bytes)
    47880    0.301    0.000    0.987    0.000 connection.py:390(_send_bytes)
    11970    0.085    0.000    0.902    0.000 connection.py:202(send)
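
Reading the numbers above: `rebuild_storage_fd` runs 11970 times for 570 `_pickle.loads` calls, so each unpickled queue item rebuilds 21 storages, and each rebuild opens its own connection and goes through `answer_challenge`. The short calculation below only restates figures from the profile. The variable names are illustrative.

```python
# Figures copied from the profile excerpt above.
pickle_loads_calls = 570       # {built-in method _pickle.loads}
storage_rebuilds = 11970       # reductions.py:354(rebuild_storage_fd)
rebuild_cumtime_s = 567.946    # cumulative seconds in rebuild_storage_fd
total_s = 890.012              # total profiled seconds

# How many tensor storages are rebuilt per item taken off the queue.
storages_per_item = storage_rebuilds / pickle_loads_calls

# Average cost of receiving one storage handle, and of one whole queue item.
seconds_per_rebuild = rebuild_cumtime_s / storage_rebuilds
seconds_per_item = storages_per_item * seconds_per_rebuild

# Fraction of the whole run spent rebuilding storages from file descriptors.
share_of_total = rebuild_cumtime_s / total_s
```

Roughly one second per received item, about 64% of the profiled time, goes into receiving the 21 per-storage handles. That suggests the cost scales with the number of tensors in a batch rather than with their total size.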