RuntimeError in PyTorch's reductions.py when using multiprocessing.Queues with multiple threads

I’m running into an issue related to sending PyTorch tensors over multiprocessing.Queues. Specifically, I have a system in which multiple background processes generate tensors and put them on individual queues, one queue per process. In the main process, I have one thread for each of these queues responsible for getting the tensors from the respective queue and processing them further. With this system, I occasionally get the following error in one of the consuming threads:

Exception in thread Thread-2:
Traceback (most recent call last):
  File "/usr/lib/python3.6/threading.py", line 916, in _bootstrap_inner
    self.run()
  File "/usr/lib/python3.6/threading.py", line 864, in run
    self._target(*self._args, **self._kwargs)
  File "<redacted>", line 1262, in _queue_fetcher_thread_fn
    msg = self._inter_process_queue.get()
  File "/usr/lib/python3.6/multiprocessing/queues.py", line 113, in get
    return _ForkingPickler.loads(res)
  File "/usr/local/lib/python3.6/dist-packages/torch/multiprocessing/reductions.py", line 300, in rebuild_storage_fd
    shared_cache[fd_id(fd)] = StorageWeakRef(storage)
  File "/usr/local/lib/python3.6/dist-packages/torch/multiprocessing/reductions.py", line 60, in __setitem__
    self.free_dead_references()
  File "/usr/local/lib/python3.6/dist-packages/torch/multiprocessing/reductions.py", line 67, in free_dead_references
    for key, storage_ref in list(self.items()):
RuntimeError: dictionary changed size during iteration

As you can see, this error comes from PyTorch’s reductions.py where the tensors are unpickled. Specifically, this is the relevant passage (with some comments removed for brevity):

import threading
from multiprocessing.util import register_after_fork

class SharedCache(dict):
    def __init__(self):
        self.limit = 128
        self._after_fork()
        register_after_fork(self, SharedCache._after_fork)

    def _after_fork(self):
        self.lock = threading.Lock()

    def __setitem__(self, key, storage_ref):
        dict.__setitem__(self, key, storage_ref)
        if len(self) > self.limit:
            self.free_dead_references()

    def free_dead_references(self):
        # Multiple Python threads may call free_dead_references() concurrently.
        # Without a lock, they may try deleting the same entry multiple times.
        with self.lock:
            live = 0
            for key, storage_ref in list(self.items()):
                if storage_ref.expired():
                    del self[key]
                else:
                    live += 1
            self.limit = max(128, live * 2)

shared_cache = SharedCache()

This SharedCache class is essentially a dictionary with some additional functionality on top. Since shared_cache is a global variable, all threads use the same instance of SharedCache. The RuntimeError is raised in the for-loop of free_dead_references. The only explanation I can come up with for how this line could fail is the following:

It seems that list(self.items()) is not an atomic operation, since self.items() returns a dictionary view, which essentially works like an iterator (see for example here). If that’s the case, it could be that Thread 1 loses the GIL while running list(self.items()). The GIL could then go to Thread 2, which could call __setitem__ and modify the dictionary. Once Thread 1 gets the GIL back, the dictionary size has changed and a RuntimeError is raised.
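
To illustrate the suspected interleaving, here is a contrived two-thread sketch (not the PyTorch code; it iterates the view directly and sleeps inside the loop purely to make the race window wide enough to hit reliably):

import threading
import time

d = {i: i for i in range(100)}
writer_may_start = threading.Event()

def writer():
    # Plays the role of Thread 2 calling __setitem__.
    writer_may_start.wait()
    for i in range(100, 200):
        d[i] = i

def reader():
    # Plays the role of Thread 1 iterating over the dict in free_dead_references.
    try:
        for key in d:
            writer_may_start.set()
            time.sleep(0.001)  # give up the GIL so the writer can interleave
    except RuntimeError as e:
        print(e)  # dictionary changed size during iteration

t_reader = threading.Thread(target=reader)
t_writer = threading.Thread(target=writer)
t_reader.start()
t_writer.start()
t_reader.join()
t_writer.join()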

As you can see, free_dead_references explicitly uses a threading.Lock to avoid exactly this kind of multithreading issue. However, __setitem__ does not acquire the lock before mutating the dictionary, so the scenario above can still occur. This suggests that the lock should be acquired in __setitem__ instead (with free_dead_references then dropping its own with-statement on the lock, since threading.Lock is not reentrant), for example:

def __setitem__(self, key, storage_ref):
    # Take the lock before mutating the dict, so that insertions cannot
    # interleave with the cleanup loop. free_dead_references would then no
    # longer acquire the lock itself, since threading.Lock is not reentrant.
    with self.lock:
        dict.__setitem__(self, key, storage_ref)
        if len(self) > self.limit:
            self.free_dead_references()

So far, I have only seen this error twice during multi-GPU training jobs. I have been trying to replicate this issue with a minimal code example for a while, but haven’t had any luck yet.
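
For context, the general shape of the setup looks roughly like this (heavily simplified, with illustrative names; the sentinel/Event-based shutdown is just for the sketch):

import threading
import torch
import torch.multiprocessing as mp

def producer(queue, done_event, num_items):
    # Each background process generates tensors and puts them on its own queue.
    for i in range(num_items):
        queue.put(torch.full((4, 4), float(i)))
    queue.put(None)  # sentinel so the consumer thread knows when to stop
    done_event.wait()  # keep the process alive until its consumer has finished

def consumer(queue):
    # One thread per queue in the main process; the tensors are unpickled
    # inside queue.get(), which is where rebuild_storage_fd and SharedCache
    # come into play.
    while True:
        msg = queue.get()
        if msg is None:
            break
        msg.sum()  # stand-in for the actual processing

if __name__ == "__main__":
    queues = [mp.Queue() for _ in range(4)]
    events = [mp.Event() for _ in range(4)]
    processes = [
        mp.Process(target=producer, args=(q, e, 1000))
        for q, e in zip(queues, events)
    ]
    threads = [threading.Thread(target=consumer, args=(q,)) for q in queues]
    for p in processes:
        p.start()
    for t in threads:
        t.start()
    for t in threads:
        t.join()
    for e in events:
        e.set()
    for p in processes:
        p.join()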

Any input on whether my explanation above could be the actual issue would be appreciated. If so, I’m happy to contribute a PR. If not, I’d be very interested to hear your opinion on what might be going on here. 🙂

I think your explanation is correct: some thread or process manipulates the dict while the loop is still being executed. Here is a minimal code snippet that reproduces the underlying issue:

x = {a: a + 10 for a in range(10)}

# Iterating without modifying the dict works fine
for key in x:
    print(key, x[key])

# Adding new keys while iterating raises the error
for key in x:
    print(key, x[key])
    x[key + 100] = key
# RuntimeError: dictionary changed size during iteration

Unfortunately, I don’t know enough about your setup to tell which process or thread could be modifying the dict behind self.items(). In any case, you could add debug statements to see when this dict is being manipulated, which might help in isolating the issue.
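
For example, something along these lines (untested, just to illustrate the idea) patches SharedCache.__setitem__ so that the global shared_cache instance logs which thread writes to it and from where:

import threading
import traceback

import torch.multiprocessing.reductions as reductions

_orig_setitem = reductions.SharedCache.__setitem__

def _logged_setitem(self, key, storage_ref):
    # Print the inserting thread and a short stack trace before delegating.
    print(f"[{threading.current_thread().name}] SharedCache.__setitem__({key})")
    traceback.print_stack(limit=5)
    _orig_setitem(self, key, storage_ref)

# Patching the class makes the existing shared_cache instance pick it up.
reductions.SharedCache.__setitem__ = _logged_setitem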