How to run a multithreaded Dataset properly in a DataLoader?

import lmdb
import random
import asyncio
import torch.utils.data as data
import time
import threading
import warnings
import torch
import pickle
import numpy as np

class LMDBDataset(data.Dataset):
    def __init__(self, root, num_samples, buffer_size=8, num_shards=0, shuffle_shards=False):
        self.root = root
        self.num_samples = num_samples
        self.num_shards = num_shards
        self.shuffle_shards = shuffle_shards
        self.buffer_size = buffer_size

        self.queue = []
        self.shard_idxs = list(range(num_shards))
        if self.shuffle_shards:
            random.shuffle(self.shard_idxs)
        self.shard_cursor = 0
        for i in range(self.buffer_size):
            self.read_shard()
        threading.Thread(target=self.ensure_buffer, daemon=True).start()
        print("START")

    def ensure_buffer(self):
        # Background thread: top the queue back up whenever it drops below buffer_size.
        print("ensure_buffer")
        while True:
            if len(self.queue) != self.buffer_size:
                self.read_shard()
            time.sleep(0.01)

    def read_shard(self):
        # Advance to the next shard (wrapping around, and reshuffling the shard
        # order at the end of a pass), open it and queue a read cursor on it.
        # print(self.shard_cursor)
        self.shard_cursor += 1
        if self.shard_cursor == self.num_shards:
            self.shard_cursor = 0
            if self.shuffle_shards:
                random.shuffle(self.shard_idxs)
        # time.sleep(0.5)
        env = lmdb.open('{}/{:08d}'.format(self.root, self.shard_idxs[self.shard_cursor]))
        txn = env.begin()
        self.queue.append(txn.cursor())


    def __len__(self):
        return self.num_samples

    # async def query_data(self, idx):

    def __getitem__(self, idx):
        print(self.shard_cursor)

        # cursor.next() returns False once the current shard is exhausted;
        # drop the spent cursor and continue with the next one in the queue.
        row = self.queue[0].next()
        if row is False:
            self.queue.pop(0)
            assert len(self.queue)
            row = self.queue[0].next()
        time.sleep(0.005)  # placeholder for real decoding work
        return 1           # dummy sample

train_dataset = LMDBDataset(root='/home/xxx/test_lmdb/imagenet', num_samples=1000000, num_shards=10, shuffle_shards=True)


train_loader = torch.utils.data.DataLoader(
        train_dataset, batch_size=256, shuffle=False,
        num_workers=4, pin_memory=True)

for i, data in enumerate(train_loader):
    print(i, data.shape)

I am trying to create a Dataset that constantly fetches a new LMDB shard whenever the size of the queue falls below the specified buffer_size. This works when the dataset is not wrapped in a DataLoader. However, ensure_buffer() never runs once the dataset is wrapped in a DataLoader. Why? How should I make this work?

Hi,

It is hard to say, but most likely the issue is the multiprocessing used by the DataLoader workers interfering with the thread you start in __init__: each worker process gets its own copy of the dataset, and a thread started in the parent process is not running inside the forked workers.

Make sure to guard your launch code with if __name__ == "__main__":. You can also test the dataset with multiprocessing outside of the DataLoader to see whether the problem reproduces.
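For example, something along these lines — a minimal sketch built on your LMDBDataset above, using the DataLoader's worker_init_fn to start the buffering thread inside every worker process (and dropping the Thread(...).start() call from __init__, so the thread is only started where it is actually needed):

    import threading
    import torch
    import torch.utils.data

    # Hypothetical worker_init_fn: each worker process gets its own copy of the
    # dataset, so the buffering thread has to be started inside every worker.
    def start_buffer_thread(worker_id):
        worker_info = torch.utils.data.get_worker_info()
        dataset = worker_info.dataset  # this worker's copy of LMDBDataset
        threading.Thread(target=dataset.ensure_buffer, daemon=True).start()

    if __name__ == "__main__":
        train_dataset = LMDBDataset(root='/home/xxx/test_lmdb/imagenet',
                                    num_samples=1000000, num_shards=10,
                                    shuffle_shards=True)
        train_loader = torch.utils.data.DataLoader(
            train_dataset, batch_size=256, shuffle=False,
            num_workers=4, pin_memory=True,
            worker_init_fn=start_buffer_thread)

Each worker then keeps its own queue of cursors topped up independently.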

OK. Is there a way to make the DataLoader read without stopping? I changed the code so that a worker simply blocks while it reads a new shard. With num_workers=4 I would expect the DataLoader to keep producing batches uninterrupted. However, every few iterations it takes much longer. Why? Setting a larger num_workers doesn't help either.


import lmdb
import random
import asyncio
import torch.utils.data as data
import time
import threading
import warnings
import torch
import pickle
import numpy as np

class LMDBDataset(data.Dataset):
    def __init__(self, root, num_samples, buffer_size=4, num_shards=0, shuffle_shards=False):
        self.root = root
        self.num_samples = num_samples
        self.num_shards = num_shards
        self.shuffle_shards = shuffle_shards
        self.buffer_size = buffer_size

        self.queue = []
        self.shard_idxs = list(range(num_shards))
        if self.shuffle_shards:
            random.shuffle(self.shard_idxs)
        self.shard_cursor = 0
        for i in range(self.buffer_size):
            self.read_shard()
        # threading.Thread(target=self.ensure_buffer, daemon=True).start()
        # print("START")

    # def ensure_buffer(self):
    #     print("ensure_buffer")
    #     while True:
    #         if len(self.queue) != self.buffer_size:
    #             self.read_shard()
    #         time.sleep(0.01)

    def read_shard(self):
        # Same as before, but with a 0.5 s sleep to simulate slow shard loading.
        # print(self.shard_cursor)
        self.shard_cursor += 1
        if self.shard_cursor == self.num_shards:
            self.shard_cursor = 0
            if self.shuffle_shards:
                random.shuffle(self.shard_idxs)
        time.sleep(0.5)
        env = lmdb.open('{}/{:08d}'.format(self.root, self.shard_idxs[self.shard_cursor]))
        txn = env.begin()
        self.queue.append(txn.cursor())


    def __len__(self):
        return self.num_samples

    # async def query_data(self, idx):

    def __getitem__(self, idx):
        # print(self.shard_cursor, len(self.queue))

        # When a shard runs out, drop its cursor and refill the queue inline,
        # blocking this worker while the new shard is read.
        row = self.queue[0].next()
        if row is False:
            self.queue.pop(0)
            if len(self.queue) < self.buffer_size:
                self.read_shard()
            # assert(len(self.queue))
            row = self.queue[0].next()
        time.sleep(0.005)  # placeholder for real decoding work
        return 1           # dummy sample

train_dataset = LMDBDataset(root='/home/xxx/test_lmdb/imagenet', num_samples=1000000, num_shards=10, shuffle_shards=True)


train_loader = torch.utils.data.DataLoader(
        train_dataset, batch_size=256, shuffle=False,
        num_workers=4, pin_memory=True)

e = time.time()
for i, data in enumerate(train_loader):

    print(i, data.shape, time.time() - e)
    e = time.time()



The timings:

0 torch.Size([256]) 3.5883443355560303
1 torch.Size([256]) 0.0035021305084228516
2 torch.Size([256]) 0.00020503997802734375
3 torch.Size([256]) 6.103515625e-05
4 torch.Size([256]) 5.14984130859375e-05
5 torch.Size([256]) 7.700920104980469e-05
6 torch.Size([256]) 0.0006682872772216797
7 torch.Size([256]) 0.0012059211730957031
8 torch.Size([256]) 0.002893209457397461
9 torch.Size([256]) 0.0031805038452148438
10 torch.Size([256]) 0.00018215179443359375
11 torch.Size([256]) 0.00013518333435058594
12 torch.Size([256]) 0.00012803077697753906
13 torch.Size([256]) 0.00012969970703125
14 torch.Size([256]) 0.00012183189392089844
15 torch.Size([256]) 0.0001266002655029297
16 torch.Size([256]) 1.2967519760131836
17 torch.Size([256]) 0.005968332290649414
18 torch.Size([256]) 0.0001418590545654297
19 torch.Size([256]) 0.000179290771484375
20 torch.Size([256]) 0.0007977485656738281
21 torch.Size([256]) 0.00016379356384277344
22 torch.Size([256]) 0.0009293556213378906
23 torch.Size([256]) 0.0010828971862792969
24 torch.Size([256]) 1.8009138107299805
25 torch.Size([256]) 0.0028998851776123047
26 torch.Size([256]) 0.002304553985595703

Some iterations just take much longer (around 2 s). What am I doing wrong?

Hi,

The DataLoader loads data from different processes, so it never actually stops.
The only reason I could see this being problematic here is that you now have multiple processes accessing the same LMDB at the same time. Could that cause locking and waits?
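If locking is indeed the culprit, one thing worth trying — a sketch only, check it against your setup — is opening each shard read-only and without LMDB's file lock in read_shard():

    # Open the shard read-only and disable LMDB's file lock so readers in
    # different worker processes do not contend on it; readahead/meminit are
    # optional knobs that can matter for large read-only databases.
    env = lmdb.open('{}/{:08d}'.format(self.root, self.shard_idxs[self.shard_cursor]),
                    readonly=True, lock=False, readahead=False, meminit=False)
    txn = env.begin(write=False)
    self.queue.append(txn.cursor())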

I think it is because reading 256 images just takes so long that even with n = 4 or 8 workers, the loop will still block every n iterations. Not sure how I can speed it up. :frowning: Any ideas?
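Rough back-of-envelope with my placeholder timings (the sleep() calls standing in for real decode work):

    # With the placeholder sleeps above, one worker spends roughly
    # batch_size * 0.005 s in __getitem__ per batch, plus 0.5 s whenever a
    # batch triggers a shard load, so it can only hand over a batch every
    # ~1.3-1.8 s.
    batch_size = 256
    per_item = 0.005        # the sleep in __getitem__
    shard_load = 0.5        # the sleep in read_shard
    print(batch_size * per_item)               # ~1.28 s without a shard load
    print(batch_size * per_item + shard_load)  # ~1.78 s, close to the long spikes above

Once the main loop has drained whatever the workers managed to prefetch, it has to wait roughly that long for the next batch, which looks like the occasional 1.3–1.8 s iterations in the timings.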

Usually you don't see these kinds of issues when loading from disk. You might want to increase the number of workers until the overall runtime starts going up again, to find the best number here.
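Something like this quick sweep (a sketch, reusing your dataset and loop from above) should show where adding workers stops helping:

    import time
    import torch.utils.data

    # Time one pass over the loader for a few worker counts and keep the
    # setting where the total runtime stops improving.
    for num_workers in (2, 4, 8, 16):
        loader = torch.utils.data.DataLoader(
            train_dataset, batch_size=256, shuffle=False,
            num_workers=num_workers, pin_memory=True)
        start = time.time()
        for _ in loader:
            pass
        print(num_workers, 'workers:', time.time() - start, 's')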