Dataset with unpicklable objects breaks DataLoader when num_workers > 0

Consider the following Dataset and DataLoader objects.

from torch.utils.data import DataLoader, IterableDataset
from pymongo import MongoClient


class MyDataset(IterableDataset):
    def __init__(self, start, end):
        super().__init__()
        self.start = start
        self.end = end

        # Connection to the MongoDB
        self.conn = MongoClient()

    def __iter__(self):
        # dummy return
        return iter(range(self.start, self.end))


dataset = MyDataset(3, 7)
data_loader = DataLoader(dataset, num_workers=1)

I’d like to build a custom IterableDataset object that streams data from a database such as MongoDB. I add the MongoDB connection in the constructor of MyDataset, but that raises a pickling error: TypeError: cannot pickle '_thread.lock' object. This is because the connection object cannot be pickled. The above code works fine if I remove the MongoDB connection line.

If I’d like to build a dataset object whose data source cannot be pickled (a database connection, file I/O, etc.), what is the best way to write the data pipeline so that it reads data in parallel?

Since most large datasets cannot fit into memory and data has to be loaded in parallel from files, databases, etc., how do people avoid this issue?
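
One commonly recommended pattern for this situation (a minimal sketch; the LazyConnDataset name and layout are illustrative, not code from this thread) is to create the connection lazily inside __iter__, so that each worker process opens its own client after it has been started and nothing unpicklable ever has to be pickled:

from torch.utils.data import DataLoader, IterableDataset
from pymongo import MongoClient


class LazyConnDataset(IterableDataset):
    def __init__(self, start, end):
        super().__init__()
        self.start = start
        self.end = end
        self.conn = None  # placeholder; nothing unpicklable is stored yet

    def __iter__(self):
        # __iter__ runs inside each worker process, so the MongoClient
        # is created there and never crosses a process boundary.
        if self.conn is None:
            self.conn = MongoClient()
        # dummy return, as in the snippet above
        return iter(range(self.start, self.end))


if __name__ == '__main__':
    data_loader = DataLoader(LazyConnDataset(3, 7), num_workers=1)
    print(list(data_loader))

With this layout the dataset holds only picklable state (start, end, and a None placeholder) at the moment the DataLoader hands it to the workers.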

@ptrblck suggested using the spawn start method in this post - DataLoader with num_workers > 0 raising error with mongodb.

Is it possible to set that in the DataLoader API?

You wouldn’t need to set it in the DataLoader directly, but could set it in the main script.
Let me know if this would work.
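
For example, something along these lines in the main script (a minimal sketch, assuming the MyDataset definition from above):

import torch
from torch.utils.data import DataLoader

if __name__ == '__main__':
    # The start method must be set once, before any DataLoader
    # worker processes are created.
    torch.multiprocessing.set_start_method('spawn')

    dataset = MyDataset(3, 7)
    data_loader = DataLoader(dataset, num_workers=1)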

Thanks for the response. I was hoping to have each worker use a different DB connection to exploit the multiprocessing capabilities. Since the DataLoader creates a new copy of the Dataset object for each worker, it seemed like a good option. Of course, I know now that it doesn’t work.

As you suggested, that can be achieved by moving the DB connection into the main script, outside the Dataset constructor. In the example below, I create a DB connection instance for each worker beforehand. This DOES work; however, the overhead of using multiple workers is substantial. The overhead does not come from the DB connection: the results are the same if I remove it.

Any thoughts on how to improve this?

import math
import torch
from torch.utils.data import DataLoader, IterableDataset
from pymongo import MongoClient
from time import time, sleep


def get_conn():
    client_db_ip = 'a.b.c.d'
    client_username = 'username'
    client_pwd = 'password'
    client_auth_db = 'test'
    client = MongoClient(client_db_ip,
                         username=client_username,
                         password=client_pwd,
                         authSource=client_auth_db)

    db = client['db']
    collec = db['test_collection']
    return collec


# Connection to the MongoDB: one client per worker
# (extend this list to match the number of workers)
connections = [get_conn(), get_conn()]


class MyDataset(IterableDataset):
    def __init__(self, start, end):
        super().__init__()
        self.start = start
        self.end = end

    def __iter__(self):
        sleep(2)
        worker_info = torch.utils.data.get_worker_info()
        if worker_info is None:  # single-process data loading, return the full iterator
            iter_start = self.start
            iter_end = self.end
            worker_id = 0
        else:  # in a worker process
            # split workload
            per_worker = int(math.ceil((self.end - self.start) / float(worker_info.num_workers)))
            worker_id = worker_info.id
            iter_start = self.start + worker_id * per_worker
            iter_end = min(iter_start + per_worker, self.end)

        # dummy query on this worker's DB connection (result intentionally unused)
        dummy_results = connections[worker_id].find().batch_size(1)

        return iter(range(iter_start, iter_end))


if __name__ == '__main__':
    t0 = time()
    NUM_WORKERS = 2
    dataset = MyDataset(3, 7)
    data_loader = DataLoader(dataset, num_workers=NUM_WORKERS, batch_size=4)

    print(list(data_loader))

    print(f'elap time for num_workers={NUM_WORKERS}: {time() - t0} s')

There is a 2-second sleep in the __iter__ method.

elap time for num_workers=0: 2.005120038986206 s
elap time for num_workers=1: 8.070252180099487 s
elap time for num_workers=2: 13.47086787223816 s
elap time for num_workers=4: 24.218919038772583 s
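
As an aside (not part of the original exchange): newer PyTorch versions (1.7+) accept persistent_workers=True, which keeps the worker processes alive across epochs, so their startup cost is paid only once rather than on every pass over the DataLoader:

data_loader = DataLoader(dataset,
                         num_workers=NUM_WORKERS,
                         batch_size=4,
                         persistent_workers=True)  # only valid with num_workers > 0

This will not change a one-shot benchmark like the one above, since the workers still have to start once, but it avoids paying that cost again on every epoch.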

I also noticed that when iterating over batches using the DataLoader, each batch is fetched from a different worker. Does that not negate the point of using multiprocessing if each batch is served sequentially by a worker? Or is the intention to have all the workers ‘prepare’ their batches in parallel, but serve them sequentially? In other words, there is no join operation after the workers have built their respective batches.