I looked everywhere and worked on this problem for a few days before solving it. Basically, data structures created with the `multiprocessing` package do not behave reliably. For example, things break predictably when you nest shared memory structures inside other shared memory structures. The ideal way to communicate asynchronously between PyTorch dataloader workers is to use process Queues, which hand the current state from one active worker process to the next, which in turn passes its updated state on to the next. Queues are certainly not elegant, but they can be made far less prone to breaking parallel processes, as the torch dev team has indicated.
Unfortunately, this was not an option for me given how my code base already stored this kind of information, but if you can modify your code to accept information popped from a Queue, then you certainly should. My problem was actually quite similar to the one described in the later parts of this post, but I ultimately just wanted to share my solution with the community because of all the trouble that reading from and writing to shared memory in PyTorch dataloaders caused me (I have cross-referenced this with other, similar posts as well). This is specific to my implementation, so you may still run into cases where it is unstable or causes deadlocks. Also keep in mind that I am only including the relevant steps for brevity. This was tested in a code base using DataParallel (DP), NOT DistributedDataParallel (DDP).
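If your code can accept state popped from a Queue, the hand-off pattern described above can be sketched as follows. This is a minimal illustration, not a dataloader integration: `worker`, `run` and the `"counter"` state are hypothetical, and plain processes stand in for dataloader workers. A single state object circulates through one queue; because `get()` blocks until the object is available, only one process holds (and mutates) the state at a time, so no separate lock is needed. `torch.multiprocessing` re-exports the same `Queue` and `Process` API.

```python
import multiprocessing as mp


def worker(state_queue, n_updates):
    # hypothetical worker: takes the shared state, updates it, hands it back
    for _ in range(n_updates):
        # get() blocks until the state is available, so only one process holds it at a time
        state = state_queue.get()
        state["counter"] += 1
        # put the updated state back so the next process can take it
        state_queue.put(state)


def run(n_procs=4, n_updates=100):
    state_queue = mp.Queue()
    # the single "token" object that carries the shared state between processes
    state_queue.put({"counter": 0})
    procs = [mp.Process(target=worker, args=(state_queue, n_updates)) for _ in range(n_procs)]
    for p in procs:
        p.start()
    for p in procs:
        p.join()
    return state_queue.get()["counter"]


if __name__ == "__main__":
    print(run())  # 400
```

Every increment happens while exactly one process owns the state, so the final count is always `n_procs * n_updates`.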
In my application, I needed a dynamic probability vector for each parent sample loaded by a worker (through its respective sampled PyTorch dataloader index), which was used to retrieve the instances linked to that batched parent sample (this is for multiple instance learning, MIL). The issue I was facing was that each worker used its own copy of the probability distribution and modified it within its own state, causing significant overfitting even on larger datasets because the workers sampled overlapping instances. This also cannot easily be remedied with PyTorch's default or custom `sampler` or `batch_sampler` (although samplers run in the main process, they do not natively sync worker state, which makes it very difficult to use them to sub-sample data like this), so it unfortunately required me to use shared memory structures, which should generally be AVOIDED, as I mentioned previously.
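The per-worker divergence described above can be reproduced with a toy dataset. A minimal sketch, assuming a hypothetical `LocalProbDataset` whose sampling state is a plain Python list: each dataloader worker works on its own copy of the dataset, so zeroing an entry in one worker is invisible to the other worker and to the main process. A list is used on purpose, because under the spawn start method a CPU tensor sent to a worker through `torch.multiprocessing` may be moved into shared memory, which would blur the demonstration.

```python
import torch
from torch.utils.data import DataLoader, Dataset


class LocalProbDataset(Dataset):
    """Hypothetical toy dataset whose sampling state is a plain (non-shared) list."""

    def __init__(self, n_instances=8, n_sample=2, n_items=4):
        self.n_sample = n_sample
        self.n_items = n_items
        self.probs = [1.0 / n_instances] * n_instances

    def __len__(self):
        return self.n_items

    def __getitem__(self, idx):
        picked = torch.multinomial(torch.tensor(self.probs), self.n_sample, replacement=False)
        for j in picked.tolist():
            # zeroes the entry in this worker's private copy only
            self.probs[j] = 0.0
        return picked


if __name__ == "__main__":
    ds = LocalProbDataset()
    loader = DataLoader(ds, batch_size=None, num_workers=2)
    picked = torch.cat(list(loader)).tolist()
    print(picked)  # indices can repeat: each worker sampled from its own copy
    print(ds.probs.count(0.0))  # 0: the main process's list was never updated
```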
Keep in mind that my solution requires shared memory tensors rather than other shared memory data structures, as I have found that they result in more predictable child process behavior (i.e. you will have to convert your data to tensors if it is stored in lists, NumPy arrays, etc.). You can nest these shared memory tensors inside ordinary (non-shared) containers such as dicts and lists, even nested ones; this did not cause any issues in my experiments. You also have to use reentrant locks (i.e. `torch.multiprocessing.Manager().RLock()`), because you will induce deadlocks if you use either plain, non-reentrant locks (i.e. `torch.multiprocessing.Manager().Lock()`, which does not let a worker that already holds the lock acquire it again, so that worker deadlocks) or no locks at all.
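The difference between the two lock types can be checked directly. A minimal sketch, assuming the standard-library `multiprocessing.Manager` (which `torch.multiprocessing` re-exports); `nested_acquire` is a hypothetical helper that takes a lock and then tries to take it again from the same process, using a timeout so the plain-lock case returns `False` instead of hanging forever.

```python
from multiprocessing import Manager


def nested_acquire(lock, timeout=0.5):
    """Hypothetical helper: acquire `lock`, then try to acquire it again from the same process."""
    with lock:
        # a reentrant lock lets its current holder acquire it again;
        # a plain lock blocks here, so a timeout is used instead of waiting forever
        got_again = lock.acquire(timeout=timeout)
        if got_again:
            lock.release()
        return got_again


if __name__ == "__main__":
    with Manager() as manager:
        print(nested_acquire(manager.RLock()))  # True
        print(nested_acquire(manager.Lock()))   # False, after waiting ~0.5 s
```

Without the timeout, the second call would block forever, which is the self-deadlock described above.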
The four key steps I used are listed here:
Step 1:
```python
# use the multiprocessing Manager; import it in your custom PyTorch Dataset file
# DON'T import the lock directly, as it can be buggy
import torch
from torch.multiprocessing import Manager
```
Step 2:
```python
# create the main reentrant lock in __init__
self.main_lock = Manager().RLock()
```
Step 3:
```python
# call share_memory_() only on the tensors you want to share, in __init__,
# or pass them in from __main__ and store them as instance attributes in __init__
# AVOID USING NESTED SHARED MEMORY STRUCTURES CONTAINING OTHER SHARED MEMORY STRUCTURES
self.input_2d_probs_dicts[j][key] = torch.tensor(
    self.input_2d_probs_dicts[j][key], dtype=torch.float32).share_memory_()
```
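To check that a shared memory tensor stored in an ordinary dict really is shared, a child process can write into it and the parent can read the result. A minimal sketch; `suppress`, `demo` and the `"case_0"` key are hypothetical, and a bare `torch.multiprocessing.Process` stands in for a dataloader worker to keep it short.

```python
import torch
import torch.multiprocessing as mp


def suppress(probs_dicts, case_id, idx):
    # runs in a child process and writes straight into the shared storage
    probs_dicts[case_id][idx] = 0.0


def demo():
    # an ordinary dict whose values are shared memory tensors (no shared container nesting)
    probs_dicts = {"case_0": torch.full((4,), 0.25).share_memory_()}
    assert probs_dicts["case_0"].is_shared()
    child = mp.Process(target=suppress, args=(probs_dicts, "case_0", 2))
    child.start()
    child.join()
    # the parent sees the child's write because both map the same storage
    return probs_dicts["case_0"].tolist()


if __name__ == "__main__":
    print(demo())  # [0.25, 0.25, 0.0, 0.25]
```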
Step 4:
```python
# in __getitem__, perform your operations as efficiently as possible while holding the lock
# if these operations are slow, you will either destroy the benefits of multiple workers or induce deadlocks
# do this in one concerted step if possible, to avoid mismatches between workers and to reduce how often workers get locked out
with self.main_lock:
    # quickly copy the shared tensor values into local memory; do this instead of accessing the shared tensor repeatedly!
    compartment_list[1][:] = self.input_2d_probs_dicts[comp_source][case_id][:]
    # count the nonzero weights for later
    non_zero = torch.count_nonzero(compartment_list[1]).item()
    # check whether there are too few nonzero weights left to sample from
    if non_zero < self.STACK_2D[i]:
        # quickly reset the shared tensor to its uniform values, then copy it locally again
        self.input_2d_probs_dicts[comp_source][case_id][:] = self.input_2d_probs_dicts_uniform[comp_source][case_id][:]
        compartment_list[1][:] = self.input_2d_probs_dicts[comp_source][case_id][:]
    # sample the instances; this uses torch.multinomial, which is not ideal given our compute constraints
    stack_id[i], compartment_id[i] = self._stack_sampling(compartment_list, i)
    # quickly suppress (zero) the sampled probabilities in the shared tensor
    self.input_2d_probs_dicts[comp_source][case_id][compartment_id[i]] = 0
# this code assumes you always have more instances than you are sampling, which is why
# the sampling probabilities are not reset again before the lock is released
```
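The steps above can be exercised end to end with a toy dataset. This is a minimal sketch under stated assumptions, not the code base described in this post: `SharedProbDataset`, `n_instances`, `n_sample` and `n_items` are hypothetical names, there is a single parent case instead of nested dicts, and plain `torch.multinomial` stands in for `_stack_sampling`. Because every read, reset, draw and suppression happens under the one reentrant lock, four draws of two instances each should cover all eight instances exactly once, no matter which worker serves which index.

```python
import torch
from torch.multiprocessing import Manager
from torch.utils.data import DataLoader, Dataset


class SharedProbDataset(Dataset):
    """Hypothetical toy dataset: one parent case whose instance probabilities
    are shared across all dataloader workers."""

    def __init__(self, n_instances=8, n_sample=2, n_items=4):
        self.n_sample = n_sample
        self.n_items = n_items
        # reentrant lock served by a Manager process, usable from every worker
        self.main_lock = Manager().RLock()
        # uniform reference vector (read-only) and the shared, mutable copy
        self.uniform = torch.full((n_instances,), 1.0 / n_instances)
        self.probs = self.uniform.clone().share_memory_()

    def __len__(self):
        return self.n_items

    def __getitem__(self, idx):
        with self.main_lock:
            # copy the shared vector once, then work on the local copy
            local = self.probs.clone()
            if torch.count_nonzero(local).item() < self.n_sample:
                # too few unsampled instances left: reset the shared vector to uniform
                self.probs[:] = self.uniform
                local = self.probs.clone()
            # draw without replacement; zero-probability instances are never chosen
            picked = torch.multinomial(local, self.n_sample, replacement=False)
            # suppress the picked instances for every worker
            self.probs[picked] = 0
        return picked


if __name__ == "__main__":
    loader = DataLoader(SharedProbDataset(), batch_size=None, num_workers=2)
    picked = torch.cat(list(loader)).tolist()
    print(sorted(picked))  # [0, 1, 2, 3, 4, 5, 6, 7]
```

Compared with a per-worker copy of the probabilities, the shared tensor plus lock is what guarantees the draws never overlap within a cycle.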
This worked reliably for me, allowing me to read from and write to the shared data structure directly from the dataloader worker processes without deadlocks. All of my tests showed that the workers were using the same tensor data, without irregularities. Keep in mind that these probability vectors are never seen by the GPUs. If you are using this sort of multiprocessing to interact with your GPUs in DDP, you will need to call `torch.multiprocessing.set_start_method('spawn')` only once, at the beginning of your `__main__` entry point (spawn is the default on Windows and, since Python 3.8, on macOS; this was done on Linux, which has traditionally defaulted to fork, and you may also encounter issues with shared memory in DDP). In my use case, all of the shared memory structures were used exclusively within the dataloader worker processes.
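A minimal sketch of where that call goes, assuming a hypothetical `main()` that builds the datasets, dataloaders and model: it runs once, inside the `if __name__ == "__main__":` guard, before any worker process, Manager or dataloader is created. The guard matters with spawn because each child process re-imports the main module.

```python
import torch.multiprocessing as mp


def main():
    # hypothetical entry point: build datasets, DataLoaders, models, etc. here
    print("start method:", mp.get_start_method())


if __name__ == "__main__":
    # set the start method exactly once, before any child process is created;
    # calling it a second time raises RuntimeError unless force=True is passed
    mp.set_start_method("spawn")
    main()  # prints "start method: spawn"
```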
@ptrblck I have seen you reply to MANY posts about this problem. I think this is the best solution if you are forced to read from and write to shared memory in a PyTorch dataloader child process without using a Queue, and it seems to work much more reliably than `torch.multiprocessing.Array()`, `torch.multiprocessing.Value()`, `torch.multiprocessing.Manager().dict()` and `torch.multiprocessing.Manager().list()`, with or without locks, as I tried both for all of them. I am not sure why that is, so perhaps this is something that should be directed to the PyTorch dev team. This may help others in the future as well. I'm just glad I got it to work well; I really did not want to reinvent the wheel in my code base!
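One documented pitfall of manager-backed containers may explain part of this, although whether it is what broke here is an assumption: reading a key from a `Manager().dict()` returns a copy of the stored value, so mutating a plain list stored inside it is silently lost unless the key is reassigned. A minimal sketch with the standard-library `Manager` (the same object `torch.multiprocessing.Manager` returns); `nested_update_demo` is a hypothetical helper.

```python
from multiprocessing import Manager


def nested_update_demo():
    """Hypothetical helper showing lost vs. propagated updates in a managed dict."""
    with Manager() as manager:
        d = manager.dict()
        # a plain list stored inside a managed dict
        d["probs"] = [0.25, 0.25, 0.25, 0.25]
        # d["probs"] returns a copy, so this mutation never reaches the manager
        d["probs"][2] = 0.0
        lost = d["probs"]
        # documented workaround: modify a local copy, then reassign the key
        tmp = d["probs"]
        tmp[2] = 0.0
        d["probs"] = tmp
        return lost, d["probs"]


if __name__ == "__main__":
    print(nested_update_demo())  # ([0.25, 0.25, 0.25, 0.25], [0.25, 0.25, 0.0, 0.25])
```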