We are training on simulated data and have written an IterableDataset that generates it. When the dataset is wrapped in a standard DataLoader it works for training, but we suspect that the serial data simulation (which involves complex calculations) is severely rate-limiting. This seems like the perfect scenario for DataLoader workers, but we get errors when we set `num_workers` greater than zero.
The issue can be reproduced with this fairly simple example:
# https://pytorch.org/docs/stable/data.html#torch.utils.data.IterableDataset
import torch
from torch.utils.data import IterableDataset, DataLoader
import numpy as np
class RandomDataset(IterableDataset):
    """Endless stream of simulated ``(x, y)`` pairs where ``y = 2 * x``.

    Each ``x`` is a 1-element float64 tensor built from ``rng.random(1)``
    (uniform on [0, 1)); ``y`` is ``2 * x``. The stream never raises
    ``StopIteration``, so callers must stop iterating themselves.

    Multi-worker notes:
      * Every DataLoader worker receives its own copy of this object
        (forked or unpickled), including ``self.rng`` in the same state.
        Without re-seeding, all workers would produce the same samples, so
        ``__iter__`` re-seeds the generator from the worker's own seed.
      * With the ``spawn`` start method, workers unpickle the dataset by
        importing its class. A class defined in a notebook/REPL ``__main__``
        cannot be found there ("Can't get attribute 'RandomDataset'"), so
        define this class in an importable ``.py`` module.
    """

    def __init__(self):
        # Zero-argument super() actually runs the parent initializer;
        # ``super(RandomDataset)`` alone is an unbound super object.
        super().__init__()
        self.rng = np.random.default_rng()

    def __next__(self):
        """Return the next ``(x, y)`` sample."""
        x = torch.tensor(self.rng.random(1))
        y = 2 * x  # already a new tensor; clone().detach() is unnecessary
        return (x, y)

    def __iter__(self):
        """Return ``self``, re-seeding the RNG first when inside a worker."""
        worker_info = torch.utils.data.get_worker_info()
        if worker_info is not None:
            # ``worker_info.seed`` is derived from the main process's torch
            # RNG and the worker id, so each worker gets a distinct stream.
            self.rng = np.random.default_rng(worker_info.seed)
        return self
# With num_workers > 0 the DataLoader starts worker processes. Under the
# "spawn" start method each worker re-imports the main module, so driver code
# must sit behind a __main__ guard; otherwise every worker would try to build
# and iterate its own DataLoader while bootstrapping.
# NOTE: in a notebook/REPL the guard is not enough -- RandomDataset must also
# be defined in an importable .py file so spawned workers can unpickle it.
if __name__ == "__main__":
    test_ds = RandomDataset()
    test_dl = DataLoader(test_ds, batch_size=20, num_workers=1)
    # The dataset is endless: print the mean of the first 12 input batches
    # (indices 0..11), then stop.
    for batch, data in enumerate(test_dl):
        print(torch.mean(data[0]))
        if batch > 10:
            break
Traceback (most recent call last):
  File "<string>", line 1, in <module>
  File "~/python3.10/multiprocessing/spawn.py", line 116, in spawn_main
    exitcode = _main(fd, parent_sentinel)
  File "~/python3.10/multiprocessing/spawn.py", line 126, in _main
    self = reduction.pickle.load(from_parent)
AttributeError: Can't get attribute 'RandomDataset' on <module '__main__' (built-in)>
Output exceeds the size limit. Open the full output data in a text editor
RuntimeError Traceback (most recent call last)
File ~/python3.10/site-packages/torch/utils/data/dataloader.py:1163, in _MultiProcessingDataLoaderIter._try_get_data(self, timeout)
1162 try:
→ 1163 data = self._data_queue.get(timeout=timeout)
1164 return (True, data)
File ~/python3.10/multiprocessing/queues.py:113, in Queue.get(self, block, timeout)
112 timeout = deadline - time.monotonic()
→ 113 if not self._poll(timeout):
114 raise Empty
File ~/python3.10/multiprocessing/connection.py:262, in _ConnectionBase.poll(self, timeout)
261 self._check_readable()
→ 262 return self._poll(timeout)
File ~/python3.10/multiprocessing/connection.py:429, in Connection._poll(self, timeout)
428 def _poll(self, timeout):
→ 429 r = wait([self], timeout)
430 return bool(r)
File ~/python3.10/multiprocessing/connection.py:936, in wait(object_list, timeout)
935 while True:
→ 936 ready = selector.select(timeout)
937 if ready:
…
→ 1176 raise RuntimeError('DataLoader worker (pid(s) {}) exited unexpectedly'.format(pids_str)) from e
1177 if isinstance(e, queue.Empty):
1178 return (False, None)
RuntimeError: DataLoader worker (pid(s) 19584) exited unexpectedly
# Same driver as above, but with num_workers=0: batches are produced in the
# main process, so no pickling/spawning is involved and this variant works.
# Kept behind a __main__ guard so that spawned DataLoader workers, which
# re-import the main module, do not re-run this loop while bootstrapping.
if __name__ == "__main__":
    test_ds = RandomDataset()
    test_dl = DataLoader(test_ds, batch_size=20, num_workers=0)
    # The dataset is endless: print the mean of the first 12 input batches
    # (indices 0..11), then stop.
    for batch, data in enumerate(test_dl):
        print(torch.mean(data[0]))
        if batch > 10:
            break
tensor(0.4931, dtype=torch.float64)
tensor(0.4804, dtype=torch.float64)
tensor(0.5428, dtype=torch.float64)
tensor(0.5644, dtype=torch.float64)
tensor(0.5368, dtype=torch.float64)
tensor(0.4914, dtype=torch.float64)
tensor(0.5419, dtype=torch.float64)
tensor(0.5024, dtype=torch.float64)
tensor(0.6744, dtype=torch.float64)
tensor(0.5859, dtype=torch.float64)
tensor(0.4658, dtype=torch.float64)
tensor(0.5022, dtype=torch.float64)