What's the best way to load large data?

I have a dataset of about 20 GB, so I can't load it directly into RAM.

I created an LMDB database for my data and wrote my own dataset class, modeled on the MNIST dataset in torchvision.

Here is my code:

import lmdb
import numpy as np
import torch.utils.data as data


class onlineHCCR(data.Dataset):
    def __init__(self, train=True):
        self.train = train
        prefix = 'train' if train else 'test'
        # Opens traindata_lmdb/trainlabel_lmdb or testdata_lmdb/testlabel_lmdb, read-only.
        self.data_env = lmdb.open(prefix + 'data_lmdb', readonly=True)
        self.label_env = lmdb.open(prefix + 'label_lmdb', readonly=True)

    def __getitem__(self, index):
        # Keys are zero-padded 8-digit strings; py-lmdb expects bytes, not str.
        key = '{:08}'.format(index).encode('ascii')

        # Each sample is stored as 150 * 6 float64 values.
        with self.data_env.begin() as txn:
            flat_data = np.frombuffer(txn.get(key), dtype=float)
        sample = flat_data.reshape(150, 6).astype('float32')

        # The label is stored as an int array; its first element is the class.
        with self.label_env.begin() as txn:
            label = np.frombuffer(txn.get(key), dtype=int)
        target = int(label[0])

        return sample, target

    def __len__(self):
        return 2693931 if self.train else 224589

But it is very slow: GPU utilization stays at around 1%, even though the process does use about 1 GB of GPU memory.

How can I solve this problem? What's the best practice for loading large datasets in PyTorch?

You can load from LMDB with multiple DataLoader workers; have you tried that?
See how I use LMDB for my LSUN dataset here: https://github.com/pytorch/vision/blob/master/torchvision/datasets/lsun.py#L19-L20

I keep readers at 1, but I use multiple workers to load from LSUN (and hence LMDB).


UPDATE

It turns out my server was fully occupied when I was running the PyTorch version of my network. It now runs at ~100 examples/sec with ~80% GPU utilization. The speed may improve further when the server has more free computing capacity.

Original answer

I have the same problem, but instead of reading from LMDB, I read the original JPEG files. GPU utilization is consistently below 50%. I have also implemented my network in TensorFlow: there it runs at ~120 examples/sec, while in PyTorch it only reaches ~60 examples/sec.

My Dataset implementation is similar to ImageFolder. Multiple workers didn't help; in fact, the more workers I use, the slower loading becomes. The same happens in my TensorFlow implementation, so it uses only 1 worker. In PyTorch, however, 0 workers is optimal for me: with 1 worker, the speed drops dramatically to ~30 examples/sec.
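To compare worker counts like this, it can help to time the DataLoader on its own, without the model. A rough sketch, assuming any map-style dataset; `examples_per_second` is a hypothetical helper, and the figure it returns includes worker start-up time, so use enough batches for that to average out.

```python
import time

from torch.utils.data import DataLoader


def examples_per_second(dataset, num_workers, batch_size=32, max_batches=50,
                        shuffle=True):
    """Rough loader-only throughput: examples yielded per second, no model."""
    loader = DataLoader(dataset, batch_size=batch_size, shuffle=shuffle,
                        num_workers=num_workers)
    seen = 0
    start = time.perf_counter()
    for step, batch in enumerate(loader):
        # (data, target) datasets yield a list per batch; count rows of the first item.
        first = batch[0] if isinstance(batch, (list, tuple)) else batch
        seen += first.shape[0]
        if step + 1 >= max_batches:
            break
    return seen / (time.perf_counter() - start)
```

For example, `for n in (0, 1, 2, 4): print(n, examples_per_second(my_dataset, n))`, where `my_dataset` stands for your own dataset. If the loader alone is already far faster than training, the bottleneck is elsewhere (for instance, a busy server).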


I suspect one reason is that you create new DB transactions every time __getitem__ is called, which adds a lot of overhead.

Also, loading the items sequentially with an iterator (as Caffe does) will be faster than calling f.get(key), but the current Dataset API doesn't seem to support iterators…

What’s your improvement?

Yes, that's true! It's much faster to use the LMDB cursor and turn off shuffling in the DataLoader. You can shuffle the dataset yourself beforehand and then read it in order.

Maybe we don't need Dataset to support iterators. __getitem__ takes an index argument; if we turn off shuffling in the DataLoader, the indices arrive in order, so the index also tells us when we have reached the end of the dataset. In that case we can use the LMDB cursor to read the data in order, one record at a time, which is much faster and raises GPU utilization.


Yes, that should definitely work. One potential issue is that __getitem__ no longer does random access, as it is supposed to. That shouldn't be much of a problem as long as you keep it in mind.

Hi, I load data that way, but it is still slow.

This is my problem: Quickly loading large data

Do you know why? Thank you.

Same problem here. Why does this happen, and how can it be improved?