Best/correct way of caching huge training data stored in chunks in lmdb format?

Hi all,

I’m looking for a way to speed up my current data loading scheme, and would really appreciate any suggestions you might have.

I’m currently trying to build a workflow to train on up to >100M data points with a total size of >20TB (each data point is more complicated than just X and y, hence the size; I can clarify more if needed). Holding all of that data in RAM and having the GPU fetch it from there is of course not possible, so my current approach is to first pre-process all the data into chunks of 5000 data points and save each chunk as a separate lmdb file: 1.lmdb, 2.lmdb, ... (each containing 5000 data points). During training, the chunks are processed sequentially: first load and train on all the data points in 1.lmdb, then load and train on the data points in 2.lmdb, and so on. I call this the “partial caching scheme”.
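For reference, each chunk is written roughly along these lines (a simplified sketch, not my exact pre-processing code: ascii-encoded integer keys and pickled data objects, which is what the dataset below expects; map_size is just a placeholder):

import pickle
import lmdb

def write_chunk(data_objects, lmdb_path):
    # Sketch: write one chunk of (up to) 5000 pre-processed data points
    # into a single lmdb file, keyed by their ascii-encoded index.
    env = lmdb.open(
        lmdb_path,
        subdir=False,
        map_size=int(1e11),  # placeholder; just needs to exceed the chunk size
        meminit=False,
        map_async=True,
    )
    with env.begin(write=True) as txn:
        for j, data_object in enumerate(data_objects):
            txn.put(
                f"{j}".encode("ascii"),
                pickle.dumps(data_object, protocol=pickle.HIGHEST_PROTOCOL),
            )
    env.sync()
    env.close()

The dataset code for the partial caching scheme looks like this: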

import bisect
import pickle

import lmdb
import numpy as np
from torch.utils.data import Dataset


class DatasetPartialCache(Dataset):
    def __init__(
        self,
        db_paths,
    ):
        self.db_paths = db_paths
        self.envs = []
        self.keys_list = []
        self.length_list = []

        for db_path in self.db_paths:
            temp_env = self.connect_db(db_path)
            num_entries = temp_env.stat()["entries"]
            self.envs.append(temp_env)
            self.keys_list.append(
                [f"{j}".encode("ascii") for j in range(num_entries)]
            )
            self.length_list.append(num_entries)

        self._keylen_cumulative = np.cumsum(self.length_list).tolist()
        self.total_length = int(np.sum(self.length_list))

        # Nothing is cached yet; the first __getitem__ call triggers a load.
        self.loaded_db_idx = None
        self.loaded_dataset = []

    def __len__(self):
        return self.total_length

    def __load_dataset__(self, db_idx):
        # Load every data point of one lmdb file into RAM.
        dataset = []
        with self.envs[db_idx].begin(write=False) as txn:
            for idx in range(self.length_list[db_idx]):
                data = txn.get(self.keys_list[db_idx][idx])
                data_object = pickle.loads(data)
                dataset.append(data_object)

        self.loaded_db_idx = db_idx
        self.loaded_dataset = dataset

    def __getitem__(self, idx):
        # Map the global index to (lmdb file, index within that file).
        db_idx = bisect.bisect(self._keylen_cumulative, idx)
        if db_idx != 0:
            el_idx = idx - self._keylen_cumulative[db_idx - 1]
        else:
            el_idx = idx

        # Reload the cache only when the index falls in a different lmdb file.
        if db_idx != self.loaded_db_idx:
            self.__load_dataset__(db_idx)

        return self.loaded_dataset[el_idx]

    @property
    def input_dim(self):
        return self[0].fingerprint.shape[1]

    def connect_db(self, lmdb_path):
        env = lmdb.open(
            lmdb_path,
            subdir=False,
            readonly=True,
            lock=False,
            readahead=False,
            meminit=False,
            max_readers=1,
        )
        return env

The functions to keep an eye on are __getitem__ and __load_dataset__. The idea is that when a data point from a not-yet-cached lmdb file is first requested, the whole lmdb file is loaded into RAM, and the sampler is specially designed to request all the data points within one lmdb file before moving on to the next.
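To illustrate the index math in __getitem__ with concrete numbers (chunks of 5000 data points):

import bisect

# cumulative chunk lengths for three lmdb files of 5000 data points each
cum_lengths = [5000, 10000, 15000]

idx = 7200
db_idx = bisect.bisect(cum_lengths, idx)                    # -> 1, i.e. 2.lmdb
el_idx = idx - cum_lengths[db_idx - 1] if db_idx else idx   # -> 2200 within that file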

The sampler looks like this, in case you are interested:

import numpy as np
import torch
from torch.utils.data import Sampler


class PartialCacheSampler(Sampler):
    def __init__(self, length_list, val_frac):
        # Keep only as many chunks as needed to cover the training split,
        # trimming the last chunk so the validation fraction is left untouched.
        len_cumulative = np.cumsum(length_list)
        len_dataset = np.sum(length_list)
        len_val = int(len_dataset * val_frac)
        len_train = len_dataset - len_val
        for i, cum_len in enumerate(len_cumulative):
            if cum_len >= len_train:
                self.length_list = list(length_list[: i + 1])
                self.length_list[-1] -= cum_len - len_train
                break

        self.num_datasets = len(self.length_list)
        self.start_idx_list = [0] + np.cumsum(self.length_list).tolist()
        self.total_length = int(np.sum(self.length_list))

    def __len__(self):
        return self.total_length

    def __iter__(self):
        # Shuffle the order of the lmdb files, then shuffle the data points
        # within each file, so each file is still consumed as one contiguous block.
        datapoint_order = []
        dataset_order = torch.randperm(self.num_datasets).tolist()
        for dataset_idx in dataset_order:
            start_idx = self.start_idx_list[dataset_idx]
            datapoint_order += [
                i + start_idx
                for i in torch.randperm(self.length_list[dataset_idx]).tolist()
            ]
        return iter(datapoint_order)

Basically, it randomizes the order of the lmdb files for each epoch (e.g. 10.lmdb > 5.lmdb > 22.lmdb > ...), and randomizes the order of the data points within each lmdb file, but still requests all the data points within one lmdb file before moving on to the next, so the partial cache scheme keeps working. I did test this scheme and it works, so it can deal with an arbitrarily large amount of training data as long as it fits on disk.
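For completeness, the two pieces are wired together roughly like this (a sketch with placeholder paths and hyperparameters, not my real config):

from torch.utils.data import DataLoader

# Sketch only: placeholder paths and settings.
db_paths = [f"{i}.lmdb" for i in range(1, 23)]

dataset = DatasetPartialCache(db_paths)
sampler = PartialCacheSampler(dataset.length_list, val_frac=0.1)

# num_workers=0 keeps the cached chunk in the main process; with workers > 0
# each worker would hold (and reload) its own copy of the current chunk.
# The trivial collate_fn just hands back the list of custom data objects.
loader = DataLoader(
    dataset,
    batch_size=64,
    sampler=sampler,
    num_workers=0,
    collate_fn=lambda batch: batch,
)

for batch in loader:
    ...  # build tensors from the data objects and run the training step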

However, I benchmarked this scheme against a similar one with all the data points from all the lmdb files pre-loaded into RAM (full cache), and it is about 3-5x slower. Further diagnosis showed that loading each lmdb file takes ~5s while training on its 5000 data points takes only ~2s, so most of the time the GPU just sits idle, waiting for data to be loaded into RAM. So I’m trying to find a way to optimize the dataset. Do you have any ideas on how to improve the current workflow, ideally with a minimal amount of changes? Maybe there are existing tools for this purpose that I’m not aware of?

One idea I had was to have independent processes “pre-load” the next few lmdb files into RAM while the model is being trained on the current one, to eliminate the I/O bottleneck, but I have absolutely no clue how to carry this out. Note that the sequence of lmdb files is randomized for every epoch.
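Something along these lines is roughly what I have in mind (just an untested sketch: it uses a background thread rather than a separate process, and set_db_order is a made-up hook that would have to be told each epoch's shuffled chunk order):

import pickle
from concurrent.futures import ThreadPoolExecutor

class PrefetchingDataset(DatasetPartialCache):
    # Sketch: overlap loading of the next lmdb file with training on the
    # current one. `set_db_order` would need to be called at the start of
    # every epoch with that epoch's shuffled chunk order.

    def __init__(self, db_paths, prefetch=1):
        super().__init__(db_paths)
        self._executor = ThreadPoolExecutor(max_workers=prefetch)
        self._pending = {}  # db_idx -> Future that resolves to the decoded chunk
        self._db_order = list(range(len(db_paths)))

    def set_db_order(self, db_order):
        self._db_order = list(db_order)

    def _read_chunk(self, db_idx):
        # Same work as __load_dataset__, but returns the chunk instead of
        # storing it, so it can run in a background thread.
        chunk = []
        with self.envs[db_idx].begin(write=False) as txn:
            for key in self.keys_list[db_idx]:
                chunk.append(pickle.loads(txn.get(key)))
        return chunk

    def __load_dataset__(self, db_idx):
        # Use the prefetched chunk if it is ready, otherwise load it now.
        future = self._pending.pop(db_idx, None)
        self.loaded_dataset = future.result() if future else self._read_chunk(db_idx)
        self.loaded_db_idx = db_idx

        # Start prefetching the next chunk in this epoch's order.
        pos = self._db_order.index(db_idx)
        if pos + 1 < len(self._db_order):
            nxt = self._db_order[pos + 1]
            if nxt not in self._pending:
                self._pending[nxt] = self._executor.submit(self._read_chunk, nxt)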

Please let me know if you have any thoughts on this. Any suggestion/comment would be greatly appreciated.