Thanks for the response. I was hoping to have each worker have a different DB connection to exploit the multiprocessing capabilities. Since the DataLoader
creates a copy a new copy of the Dataset
object, it seemed like a good option. Ofcourse, I know now it doesn’t work.
As you suggested, that can be achieved by moving the DB connection in the main script, outside the Dataset
constructor. In the example below, I create DB connection instances for each of the workers beforehand. This DOES work, however the overhead in using multiple workers is substantial. The overhead is not dependent on the DB connection. The results are the same if I remove the DB connection.
Any thoughts, on how to improve this?
import torch
from torch.utils.data import DataLoader, IterableDataset
from pymongo import MongoClient
from time import time, sleep
def get_conn():
client_db_ip = 'a.b.c.d'
client_username = 'username'
client_pwd = 'password'
client_auth_db = 'test'
client = MongoClient(client_db_ip,
username=client_username,
password=client_pwd,
authSource=client_auth_db)
db = client['db']
collec = db['test_collection']
return collec
# Connection to the MongoDB
connections = [get_conn(), get_conn()]
class MyDataset(IterableDataset):
def __init__(self, start, end):
super().__init__()
self.start = start
self.end = end
def __iter__(self):
sleep(2)
worker_info = torch.utils.data.get_worker_info()
if worker_info is None: # single-process data loading, return the full iterator
iter_start = self.start
iter_end = self.end
worker_id = 0
else: # in a worker process
# split workload
per_worker = int(math.ceil((self.end - self.start) / float(worker_info.num_workers)))
worker_id = worker_info.id
iter_start = self.start + worker_id * per_worker
iter_end = min(iter_start + per_worker, self.end)
# dummy data from DB connection
dummy_results = connections[worker_id].find().batch_size(1)
return iter(range(iter_start, iter_end))
if __name__ == '__main__':
t0 = time()
NUM_WORKERS = 2
dataset = MyDataset(3, 7)
data_loader = DataLoader(dataset, num_workers=NUM_WORKERS, batch_size=4)
print(list(data_loader))
print(f'elap time for num_workers={NUM_WORKERS}: {time() - t0} s')
There is a 2 second sleep in the __iter__
method.
elap time for num_workers=0: 2.005120038986206 s
elap time for num_workers=1: 8.070252180099487 s
elap time for num_workers=2: 13.47086787223816 s
elap time for num_workers=4: 24.218919038772583 s
I also noticed that when iterating over batches using DataLoader
, each batch is fetched from a different worker. Does that not negate the point of using multiprocessing if each batch is served sequentially by a worker? Or is the intention to have all the worker ‘prepare’ the batch in parallel, but server them sequentially? In other words, there is no join operation after the workers have built their respective batches.