AssertionError with DistributedDataParallel

I’m trying to use DistributedDataParallel with PyTorch and I’m running into this error, which I haven’t found a solution for elsewhere:

-- Process 2 terminated with the following error:
Traceback (most recent call last):
  File "/afs/", line 20, in _wrap
    fn(i, *args)
  File "../", line 5, in train_model
    model.train(args)
  File "../", line 227, in train
    for i, data_dict in enumerate(train_dataloader, 1):
  File "/afs/", line 363, in __next__
    data = self._next_data()
  File "/afs/", line 402, in _next_data
    index = self._next_index()  # may raise StopIteration
  File "/afs/", line 357, in _next_index
    return next(self._sampler_iter)  # may raise StopIteration
  File "/afs/", line 208, in __iter__
    for idx in self.sampler:
  File "/afs/", line 84, in __iter__
    assert len(indices) == self.num_samples
AssertionError

How would I approach fixing this problem?

I can comment out the assertion, but I’m hoping there is a cleaner way to do this.
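For context, the assertion that fires lives in `DistributedSampler.__iter__`: the sampler pads the index list so that it divides evenly across processes, then asserts that each rank's shard has exactly `num_samples` entries. The sketch below is not the actual PyTorch source, just an illustration (with a hypothetical helper name) of why the dataset length reported by `__len__` must be consistent for the assertion to hold:

```python
import math

def make_padded_indices(dataset_len, num_replicas, rank):
    """Illustrative sketch of how a distributed sampler pads and shards
    indices so every replica draws the same number of samples."""
    num_samples = math.ceil(dataset_len / num_replicas)
    total_size = num_samples * num_replicas

    indices = list(range(dataset_len))
    # Pad by repeating indices from the front so total_size divides evenly.
    indices += indices[: total_size - len(indices)]

    # Each rank takes a strided slice of the padded list.
    shard = indices[rank:total_size:num_replicas]
    # This mirrors the `assert len(indices) == self.num_samples`
    # from the traceback: it fails if the dataset length the sampler
    # saw at construction no longer matches what __len__ reports.
    assert len(shard) == num_samples
    return shard
```

So a cleaner fix than deleting the assertion is usually to make sure the dataset's `__len__` is stable and identical across all processes.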

This error occurs in the data loader. Could you please share a repro of your data loader setup?

cc @VitalyFedyunin for data loader questions

This also looks similar to (and the proposal in ). Do you use DistributedSampler?
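For reference, this is the typical way to wire a `DistributedSampler` into a `DataLoader` (a minimal sketch; `build_loader`, `world_size`, and `rank` are placeholder names, not from the original post):

```python
import torch
from torch.utils.data import DataLoader, TensorDataset
from torch.utils.data.distributed import DistributedSampler

def build_loader(dataset, world_size, rank, batch_size=8):
    # One sampler per process: it shards (and pads) the indices so
    # every rank draws the same number of samples per epoch.
    sampler = DistributedSampler(dataset, num_replicas=world_size, rank=rank)
    loader = DataLoader(dataset, batch_size=batch_size, sampler=sampler)
    return loader, sampler

dataset = TensorDataset(torch.arange(10).float())
loader, sampler = build_loader(dataset, world_size=2, rank=0)

for epoch in range(2):
    sampler.set_epoch(epoch)  # reshuffle differently each epoch
    for (batch,) in loader:
        pass
```

Note that when the sampler is created from an initialized process group, `num_replicas` and `rank` can be omitted and are picked up automatically.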

Yeah! It ended up being a problem with my data loader.

I did something and now it works, although I’m not sure what I did.

Could a possible problem be that the data loader was returning None for the elements occasionally?
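If the dataset can occasionally return `None`, one common pattern is to filter those items out in a custom `collate_fn` passed to the `DataLoader`, rather than letting them reach batching. A minimal sketch (the function name is illustrative; in real PyTorch code you would typically forward the filtered list to `default_collate`):

```python
def collate_skip_none(batch):
    """Drop items the dataset returned as None, so a corrupt or
    missing sample shrinks the batch instead of crashing iteration."""
    filtered = [item for item in batch if item is not None]
    if not filtered:
        raise RuntimeError("Entire batch was None; check the dataset.")
    return filtered
```

Usage would look like `DataLoader(dataset, collate_fn=collate_skip_none, ...)`. Note, though, that `None` elements by themselves would not normally trip the sampler-length assertion above; that one points at an inconsistent `__len__`.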

On another note, I’m now getting an Exception: process 0 terminated with signal SIGSEGV when I run my model.