import lmdb
import random
import asyncio
import torch.utils.data as data
import time
import threading
import warnings
import torch
import pickle
import numpy as np
class LMDBDataset(data.Dataset):
    def __init__(self, root, num_samples, buffer_size=8, num_shards=0, shuffle_shards=False):
        self.root = root
        self.num_samples = num_samples
        self.num_shards = num_shards
        self.shuffle_shards = shuffle_shards
        self.buffer_size = buffer_size
        self.queue = []
        self.shard_idxs = list(range(num_shards))
        if self.shuffle_shards:
            random.shuffle(self.shard_idxs)
        self.shard_cursor = 0
        for i in range(self.buffer_size):
            self.read_shard()
        threading.Thread(target=self.ensure_buffer, daemon=True).start()
        print("START")

    def ensure_buffer(self):
        print("ensure_buffer")
        while True:
            if len(self.queue) != self.buffer_size:
                self.read_shard()
            time.sleep(0.01)

    def read_shard(self):
        # print(self.shard_cursor)
        self.shard_cursor += 1
        if self.shard_cursor == self.num_shards:
            self.shard_cursor = 0
            if self.shuffle_shards:
                random.shuffle(self.shard_idxs)
        # time.sleep(0.5)
        env = lmdb.open('{}/{:08d}'.format(self.root, self.shard_idxs[self.shard_cursor]))
        txn = env.begin()
        self.queue.append(txn.cursor())

    def __len__(self):
        return self.num_samples

    # async def query_data(self, idx):
    def __getitem__(self, idx):
        print(self.shard_cursor)
        row = self.queue[0].next()
        if row is False:
            self.queue.pop(0)
            assert(len(self.queue))
            row = self.queue[0].next()
        time.sleep(0.005)
        return 1


train_dataset = LMDBDataset(root='/home/xxx/test_lmdb/imagenet', num_samples=1000000, num_shards=10, shuffle_shards=True)
train_loader = torch.utils.data.DataLoader(
    train_dataset, batch_size=256, shuffle=False,
    num_workers=4, pin_memory=True)

for i, data in enumerate(train_loader):
    print(i, data.shape)
I am trying to create a Dataset that constantly fetches a new LMDB shard whenever the size of the queue falls below the specified buffer_size. It works when it is not wrapped in a DataLoader. However, the function ensure_buffer() does not run once the dataset is wrapped in a DataLoader. Why? How should I make this work?
It is hard to say, but most likely the issue is the multiprocessing used by the DataLoader workers, which does not play well with the thread-based buffering you use: with num_workers > 0, each worker process gets its own copy of the dataset, so a thread started in the parent process is not running in those copies.
Make sure to use if __name__ == "__main__": properly, and you can test with multiprocessing outside of the DataLoader to see if the same problem shows up.
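For reference, a minimal sketch of the guarded entry point (same class, paths, and sizes as in your snippet, just wrapped in the guard):

# Sketch: the script-level part of the snippet above, wrapped in the
# __main__ guard so that DataLoader worker processes can import the
# module without re-running this code.
if __name__ == "__main__":
    train_dataset = LMDBDataset(root='/home/xxx/test_lmdb/imagenet',
                                num_samples=1000000, num_shards=10,
                                shuffle_shards=True)
    train_loader = torch.utils.data.DataLoader(
        train_dataset, batch_size=256, shuffle=False,
        num_workers=4, pin_memory=True)
    for i, batch in enumerate(train_loader):
        print(i, batch.shape)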
Ok. Is there a way to make the DataLoader read without stopping? I changed the code so that one of the processes will just block while it reads a new shard. With num_workers=4, I would expect the DataLoader to be able to read uninterrupted. However, it keeps taking longer every few iterations. Why? Setting a larger num_workers doesn't help either.
import lmdb
import random
import asyncio
import torch.utils.data as data
import time
import threading
import warnings
import torch
import pickle
import numpy as np
class LMDBDataset(data.Dataset):
    def __init__(self, root, num_samples, buffer_size=4, num_shards=0, shuffle_shards=False):
        self.root = root
        self.num_samples = num_samples
        self.num_shards = num_shards
        self.shuffle_shards = shuffle_shards
        self.buffer_size = buffer_size
        self.queue = []
        self.shard_idxs = list(range(num_shards))
        if self.shuffle_shards:
            random.shuffle(self.shard_idxs)
        self.shard_cursor = 0
        for i in range(self.buffer_size):
            self.read_shard()
        # threading.Thread(target=self.ensure_buffer, daemon=True).start()
        # print("START")

    # def ensure_buffer(self):
    #     print("ensure_buffer")
    #     while True:
    #         if len(self.queue) != self.buffer_size:
    #             self.read_shard()
    #         time.sleep(0.01)

    def read_shard(self):
        # print(self.shard_cursor)
        self.shard_cursor += 1
        if self.shard_cursor == self.num_shards:
            self.shard_cursor = 0
            if self.shuffle_shards:
                random.shuffle(self.shard_idxs)
        time.sleep(0.5)
        env = lmdb.open('{}/{:08d}'.format(self.root, self.shard_idxs[self.shard_cursor]))
        txn = env.begin()
        self.queue.append(txn.cursor())

    def __len__(self):
        return self.num_samples

    # async def query_data(self, idx):
    def __getitem__(self, idx):
        # print(self.shard_cursor, len(self.queue))
        row = self.queue[0].next()
        if row is False:
            self.queue.pop(0)
            if len(self.queue) < self.buffer_size:
                self.read_shard()
            # assert(len(self.queue))
            row = self.queue[0].next()
        time.sleep(0.005)
        return 1


train_dataset = LMDBDataset(root='/home/xxx/test_lmdb/imagenet', num_samples=1000000, num_shards=10, shuffle_shards=True)
train_loader = torch.utils.data.DataLoader(
    train_dataset, batch_size=256, shuffle=False,
    num_workers=4, pin_memory=True)

e = time.time()
for i, data in enumerate(train_loader):
    print(i, data.shape, time.time() - e)
    e = time.time()
The DataLoader loads data from different processes, so it never stops on its own.
The only reason I could see this being problematic here is that you now have multiple processes accessing the same LMDB at the same time. Could that cause locking and waits?
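One thing you could try (a sketch, not tested on your data): open each shard read-only and without the file lock, so readers in different worker processes don't contend on LMDB's lock file.

# Hypothetical variant of read_shard: open shards read-only and without the
# OS file lock, so several DataLoader workers reading the same shards do not
# serialize on LMDB's lock file.
def read_shard(self):
    self.shard_cursor += 1
    if self.shard_cursor == self.num_shards:
        self.shard_cursor = 0
        if self.shuffle_shards:
            random.shuffle(self.shard_idxs)
    env = lmdb.open('{}/{:08d}'.format(self.root, self.shard_idxs[self.shard_cursor]),
                    readonly=True, lock=False, readahead=False, meminit=False)
    txn = env.begin(write=False)
    self.queue.append(txn.cursor())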
I think it is because reading 256 images just takes so long that even with n=4 or 8 workers, it will still block every n iterations. Not sure how I can speed this up. Any ideas?
Usually, when loading from disk, you don't see this kind of issue. You might want to increase the number of workers until the overall runtime starts going up again, to find the best number here.
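A rough sketch of how to find that sweet spot (assumes the LMDBDataset above; the batch size and worker counts are just examples): time a fixed number of batches for each worker count and keep the fastest.

# Rough benchmark sketch: time a fixed number of batches for several
# num_workers values and pick the fastest. Assumes LMDBDataset as above.
import itertools
import time
import torch

def benchmark(dataset, num_workers, num_batches=50):
    loader = torch.utils.data.DataLoader(
        dataset, batch_size=256, shuffle=False,
        num_workers=num_workers, pin_memory=True)
    start = time.time()
    for _ in itertools.islice(loader, num_batches):
        pass
    return time.time() - start

if __name__ == "__main__":
    for n in (2, 4, 8, 16):
        # re-create the dataset each time so every run starts from fresh cursors
        dataset = LMDBDataset(root='/home/xxx/test_lmdb/imagenet',
                              num_samples=1000000, num_shards=10,
                              shuffle_shards=True)
        print(n, 'workers:', benchmark(dataset, n), 's')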