Load 300GB across 200 `.npy` files with Dataset and DataLoader

Hi, I have a large dataset (~300GB) stored across ~200 NumPy files (.npy). I wrote my own custom Dataset class to load a NumPy file and batch it dynamically, because RAM is limited and I don't want to load all ~200 files at once. Is this the right tool for this use case, and if so, what do you recommend I do to make it work? Thank you in advance!

import numpy as np
import torch
from torch.utils.data import Dataset, DataLoader

class MYDS(Dataset):
    def __init__(self, x_paths, y_paths):
        self.xs = None
        self.ys = None

        for x_path, y_path in zip(x_paths, y_paths):
            print(f'loaded: {x_path}, {y_path}')
            # Each iteration loads a whole file into RAM and overwrites the
            # previous arrays, so only the last pair of files is kept.
            self.xs = torch.from_numpy(np.load(x_path))
            self.ys = torch.from_numpy(np.load(y_path))
    def __len__(self):
        return len(self.xs)

    def __getitem__(self, idx):
        return [self.xs[idx], self.ys[idx]]

if __name__ == '__main__':
    x_paths = ['2018-05_part-00000_train_x.npy', '2018-05_part-00001_train_x.npy', '2018-05_part-00002_train_x.npy'] # small subset of actual dataset
    y_paths = ['2018-05_part-00000_train_y.npy', '2018-05_part-00001_train_y.npy', '2018-05_part-00002_train_y.npy'] # small subset of actual dataset

    ds = MYDS(x_paths, y_paths)
    dl = DataLoader(ds, batch_size=128, shuffle=False, num_workers=2)

    for batch_idx, (x, y) in enumerate(dl):
        print(batch_idx, x[0], y[0]) # loading batches only from x_paths[-1] and y_paths[-1] numpy files

I would go with NumPy memory mapping. You can create one Dataset per NumPy file and then use a ConcatDataset to combine them all. That would be the easiest way.
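A minimal sketch of that suggestion, assuming each x/y pair of files has the same number of rows along axis 0 and was written with `np.save`. It uses `np.load(..., mmap_mode="r")`, which memory-maps a `.npy` file and handles its header for you (plain `np.memmap` would need the header offset, shape and dtype supplied by hand). The class name `NpyPairDataset`, the helper `make_loader` and the generated file names are hypothetical.

```python
import os
import tempfile

import numpy as np
import torch
from torch.utils.data import ConcatDataset, DataLoader, Dataset


class NpyPairDataset(Dataset):
    """Hypothetical dataset over one (x, y) pair of .npy files, read via memory mapping."""

    def __init__(self, x_path, y_path):
        self.x_path = x_path
        self.y_path = y_path
        # mmap_mode="r" maps the file without reading the array data into RAM.
        n_x = len(np.load(x_path, mmap_mode="r"))
        n_y = len(np.load(y_path, mmap_mode="r"))
        if n_x != n_y:
            raise ValueError(f"length mismatch: {x_path} ({n_x}) vs {y_path} ({n_y})")
        self.length = n_x
        # Opened lazily in __getitem__, so each DataLoader worker maps the
        # files itself instead of receiving a pickled array.
        self.xs = None
        self.ys = None

    def __len__(self):
        return self.length

    def __getitem__(self, idx):
        if self.xs is None:
            self.xs = np.load(self.x_path, mmap_mode="r")
            self.ys = np.load(self.y_path, mmap_mode="r")
        # np.array(...) copies just this row from disk into a writable array.
        x = torch.as_tensor(np.array(self.xs[idx]))
        y = torch.as_tensor(np.array(self.ys[idx]))
        return x, y


def make_loader(x_paths, y_paths, **loader_kwargs):
    # One small dataset per file pair; ConcatDataset stitches them together.
    ds = ConcatDataset([NpyPairDataset(xp, yp) for xp, yp in zip(x_paths, y_paths)])
    return DataLoader(ds, **loader_kwargs)


if __name__ == "__main__":
    # Tiny synthetic stand-ins for the real 2018-05_part-*_train_{x,y}.npy files.
    tmp = tempfile.mkdtemp()
    x_paths, y_paths = [], []
    for i, n in enumerate([300, 200, 250]):
        xp = os.path.join(tmp, f"part-{i:05d}_train_x.npy")
        yp = os.path.join(tmp, f"part-{i:05d}_train_y.npy")
        np.save(xp, np.random.rand(n, 8).astype(np.float32))
        np.save(yp, np.full(n, i, dtype=np.int64))  # label = file number
        x_paths.append(xp)
        y_paths.append(yp)

    dl = make_loader(x_paths, y_paths, batch_size=128, shuffle=True, num_workers=2)
    for batch_idx, (x, y) in enumerate(dl):
        # With shuffle=True, batches typically contain rows from several files.
        print(batch_idx, x.shape, y.unique().tolist())
```

Only the rows a batch actually indexes are copied into RAM, so a single DataLoader over the concatenated dataset is enough; no per-file DataLoader is needed.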

OK, so I could create a Dataset per NumPy file, then use np.memmap to partially load each file so that I never load everything into RAM at once. Would I also need a DataLoader per Dataset?

Not really, there is already a class for that: ConcatDataset. You instantiate it with a list of all your datasets and it behaves like a single Dataset, so that len_total = sum(len_i), and each index is routed to the sub-dataset it falls in. One DataLoader over the ConcatDataset is enough.

https://pytorch.org/docs/stable/data.html?highlight=concat#torch.utils.data.ConcatDataset
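For intuition about how an index into the combined dataset reaches the right file: ConcatDataset keeps running totals of the sub-dataset lengths (its `cumulative_sizes` attribute) and uses `bisect` to find which dataset a global index falls in. The function below is a simplified, standalone re-implementation of that lookup for illustration only; `locate` is a hypothetical name, not part of PyTorch.

```python
import bisect
from itertools import accumulate


def locate(global_idx, lengths):
    """Map an index into the concatenation to (dataset number, local index)."""
    cumulative = list(accumulate(lengths))  # e.g. [100, 150, 225]
    total = cumulative[-1]                  # len_total = sum(len_i)
    if global_idx < 0:                      # support negative indices
        global_idx += total
    if not 0 <= global_idx < total:
        raise IndexError(global_idx)
    # First dataset whose running total exceeds the index.
    ds_idx = bisect.bisect_right(cumulative, global_idx)
    start = cumulative[ds_idx - 1] if ds_idx > 0 else 0
    return ds_idx, global_idx - start


lengths = [100, 50, 75]        # hypothetical per-file row counts
print(sum(lengths))            # 225
print(locate(120, lengths))    # (1, 20)
```

Because the lookup is a binary search over ~200 running totals, routing an index to its file costs almost nothing compared to reading the row from disk.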