I’m running into an issue related to sending PyTorch tensors over `multiprocessing.Queue`s. Specifically, I have a system in which multiple background processes generate tensors and put them on individual queues, one queue per process. In the main process, I have one thread per queue that is responsible for getting the tensors from its queue and processing them further. With this system, I occasionally get the following error in one of the consuming threads:
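For reference, here is a stripped-down sketch of the structure (not a reproducer; the names are placeholders and the real code does more work per tensor):

```python
import threading
import torch
import torch.multiprocessing as mp


def producer(queue, done):
    # Background process: generates tensors and puts them on its own queue.
    for _ in range(1000):
        queue.put(torch.randn(64, 64))
    done.wait()  # keep the process alive until the main process has consumed everything


def queue_fetcher(queue):
    # Main-process thread: gets tensors from one queue and processes them further.
    # The unpickling inside queue.get() goes through torch.multiprocessing's reductions.
    for _ in range(1000):
        tensor = queue.get()
        tensor.sum()  # stand-in for the real processing


if __name__ == "__main__":
    num_workers = 4
    queues = [mp.Queue(maxsize=8) for _ in range(num_workers)]
    done = mp.Event()
    procs = [mp.Process(target=producer, args=(q, done)) for q in queues]
    threads = [threading.Thread(target=queue_fetcher, args=(q,)) for q in queues]
    for p in procs:
        p.start()
    for t in threads:
        t.start()
    for t in threads:
        t.join()
    done.set()
    for p in procs:
        p.join()
```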
```
Exception in thread Thread-2:
Traceback (most recent call last):
  File "/usr/lib/python3.6/threading.py", line 916, in _bootstrap_inner
    self.run()
  File "/usr/lib/python3.6/threading.py", line 864, in run
    self._target(*self._args, **self._kwargs)
  File "<redacted>", line 1262, in _queue_fetcher_thread_fn
    msg = self._inter_process_queue.get()
  File "/usr/lib/python3.6/multiprocessing/queues.py", line 113, in get
    return _ForkingPickler.loads(res)
  File "/usr/local/lib/python3.6/dist-packages/torch/multiprocessing/reductions.py", line 300, in rebuild_storage_fd
    shared_cache[fd_id(fd)] = StorageWeakRef(storage)
  File "/usr/local/lib/python3.6/dist-packages/torch/multiprocessing/reductions.py", line 60, in __setitem__
    self.free_dead_references()
  File "/usr/local/lib/python3.6/dist-packages/torch/multiprocessing/reductions.py", line 67, in free_dead_references
    for key, storage_ref in list(self.items()):
RuntimeError: dictionary changed size during iteration
```
As you can see, this error comes from PyTorch’s `reductions.py`, where the tensors are unpickled. Specifically, this is the relevant passage (with some comments removed for brevity):
```python
class SharedCache(dict):
    def __init__(self):
        self.limit = 128
        self._after_fork()
        register_after_fork(self, SharedCache._after_fork)

    def _after_fork(self):
        self.lock = threading.Lock()

    def __setitem__(self, key, storage_ref):
        dict.__setitem__(self, key, storage_ref)
        if len(self) > self.limit:
            self.free_dead_references()

    def free_dead_references(self):
        # Multiple Python threads may call free_dead_references() concurrently.
        # Without a lock, they may try deleting the same entry multiple times.
        with self.lock:
            live = 0
            for key, storage_ref in list(self.items()):
                if storage_ref.expired():
                    del self[key]
                else:
                    live += 1
            self.limit = max(128, live * 2)


shared_cache = SharedCache()
```
This `SharedCache` class is essentially a dictionary with some additional functionality on top. Since `shared_cache` is a global variable, all threads use the same instance of `SharedCache`. The `RuntimeError` is raised in the for-loop in `free_dead_references`. The only explanation I can come up with for how this line could fail is the following:
It seems that `list(self.items())` is not an atomic operation, since `self.items()` returns a dictionary view, which essentially works like an iterator (see for example here). If that’s the case, Thread 1 could lose the GIL while running `list(self.items())`. The GIL could then go to Thread 2, which could call `__setitem__` and modify the dictionary. Once Thread 1 gets the GIL back, the dictionary size has changed and a `RuntimeError` is raised.
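To convince myself that this is the right failure mode, here is a toy snippet (unrelated to PyTorch, and iterating the view with a plain for-loop rather than `list()`, so the race window is much wider) in which one thread inserts keys while another iterates; it reliably raises the same `RuntimeError`. Whether `list(self.items())` can actually be interrupted in the same way is exactly the part I’m unsure about:

```python
import threading

d = {i: i for i in range(100_000)}
stop = threading.Event()


def writer():
    # Mimics __setitem__ being called from another consumer thread.
    key = len(d)
    while not stop.is_set():
        d[key] = None
        key += 1


t = threading.Thread(target=writer)
t.start()
try:
    # Mimics the loop in free_dead_references, but iterating the view
    # directly instead of materializing it with list() first.
    for _ in range(1_000):
        for key, value in d.items():
            pass
except RuntimeError as exc:
    print(exc)  # "dictionary changed size during iteration"
finally:
    stop.set()
    t.join()
```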
As you can see, `free_dead_references` explicitly uses a `threading.Lock` to avoid multithreading issues. However, `__setitem__` does not take this lock, so the scenario above can occur. This suggests that the lock acquisition should be moved to `__setitem__`, as follows:
```python
def __setitem__(self, key, storage_ref):
    with self.lock:
        dict.__setitem__(self, key, storage_ref)
        if len(self) > self.limit:
            self.free_dead_references()
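```

To spell out what I mean by “moved”: `free_dead_references` would then drop its own `with self.lock:` block, since `threading.Lock` is not reentrant and keeping it would deadlock when the method is called from the now-locked `__setitem__`. Roughly:

```python
def free_dead_references(self):
    # No lock acquisition here anymore: the caller (__setitem__)
    # already holds self.lock, and threading.Lock is not reentrant.
    live = 0
    for key, storage_ref in list(self.items()):
        if storage_ref.expired():
            del self[key]
        else:
            live += 1
    self.limit = max(128, live * 2)
```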
So far, I have only seen this error twice during multi-GPU training jobs. I have been trying to replicate this issue with a minimal code example for a while, but haven’t had any luck yet.
Any input on whether my explanation above could be the actual issue would be appreciated. If so, I’m happy to contribute a PR. If not, I’d be very interested to hear your opinion on what might be going on here.