ConcatDataset makes training slow down exponentially

Hello. I have several numpy .npy dataset files. Each of them has thousands of samples, with a shape like (num_samples, channels, width, height).

If I use one of the files as the training set and load it in a Dataset:

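import numpy as np
import torch
from torch.utils.data import Dataset, DataLoader

class MyDataset(Dataset):
    """npy dataset"""

    def __init__(self, input_file, label_file):
        # mmap_mode='r' maps the files read-only instead of reading them into RAM
        self.inputs = np.load(input_file, mmap_mode='r')
        self.labels = np.load(label_file, mmap_mode='r')

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, idx):
        # copy the sample out of the memmap so torch gets its own writable array
        X = torch.from_numpy(np.copy(self.inputs[idx]))
        y = torch.from_numpy(np.copy(self.labels[idx]))
        return X, y

# input_file and label_file are the paths to one pair of .npy files
trainset = MyDataset(input_file, label_file)
trainloader = DataLoader(trainset, batch_size=256)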

It trains very fast. But if I create one Dataset per .npy file and combine them into a ConcatDataset, training slows down much faster than the amount of data grows: one epoch over 10 files takes about 100x as long as training on a single file.
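For context, `torch.utils.data.ConcatDataset` does not copy the data: it stores the cumulative sizes of the wrapped datasets and maps each global index to a (dataset, local index) pair with `bisect.bisect_right`. The pure-Python sketch below mirrors that lookup (the helper names `cumulative_sizes` and `locate` are hypothetical, not PyTorch's API). The lookup costs O(log k) for k files, so the index bookkeeping by itself is unlikely to explain a 100x slowdown; the cost is more likely in how the underlying memory-mapped files are read.

```python
import bisect
from itertools import accumulate


def cumulative_sizes(lengths):
    """Running totals of the per-file sample counts, e.g. [3, 5] -> [3, 8]."""
    return list(accumulate(lengths))


def locate(cum_sizes, idx):
    """Map a global index to (file_index, local_index), mirroring the
    bisect-based lookup that ConcatDataset performs."""
    if idx < 0:
        idx += cum_sizes[-1]  # support negative indexing like a list
    if not 0 <= idx < cum_sizes[-1]:
        raise IndexError(idx)
    file_idx = bisect.bisect_right(cum_sizes, idx)
    local_idx = idx if file_idx == 0 else idx - cum_sizes[file_idx - 1]
    return file_idx, local_idx


sizes = cumulative_sizes([1000, 1000, 500])
print(locate(sizes, 0))     # (0, 0)
print(locate(sizes, 999))   # (0, 999)
print(locate(sizes, 1000))  # (1, 0)
print(locate(sizes, 2499))  # (2, 499)
```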

How can I speed up training? I guess I could load the dataset into memory one file at a time. Is that possible?

Thank you!

Could you check if you are running out of host memory and are using swap?
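On Linux, one quick way to check this is to look at `/proc/meminfo` (or run `free -h`) while training. The sketch below assumes a Linux host, and the helper names `parse_meminfo` and `memory_report` are hypothetical. It reads `MemAvailable` and the swap fields, so you can watch whether available memory shrinks and swap usage grows as more files are added to the ConcatDataset.

```python
from pathlib import Path


def parse_meminfo(text):
    """Parse /proc/meminfo-style text into a dict of integer values (mostly kB)."""
    info = {}
    for line in text.splitlines():
        key, _, rest = line.partition(":")
        fields = rest.split()
        if fields and fields[0].isdigit():  # first field is the numeric value
            info[key.strip()] = int(fields[0])
    return info


def memory_report(info):
    """Return (available_GiB, swap_used_GiB) from a parsed meminfo dict."""
    kib_per_gib = 1024 ** 2
    available = info["MemAvailable"] / kib_per_gib
    swap_used = (info["SwapTotal"] - info["SwapFree"]) / kib_per_gib
    return available, swap_used


meminfo = Path("/proc/meminfo")  # assumption: Linux-only path
if meminfo.exists():
    avail, swap = memory_report(parse_meminfo(meminfo.read_text()))
    print(f"available: {avail:.1f} GiB, swap in use: {swap:.1f} GiB")
```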

Yes, this would be possible, but you would have to consider whether and how shuffling should work.
E.g., if you lazily open each file in the __getitem__ method, I would expect the performance to be really bad, since every sample access would have to reopen a file.
On the other hand, you could implement a logic, which opens a single numpy file, shuffles the indices, reads all samples, and loads the next file.
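A minimal sketch of that idea, assuming each input file has a matching label file with the same number of samples and that a single file fits in RAM. The generator name `iterate_file_by_file` is hypothetical. It shuffles the order of the files, loads one file completely into memory, shuffles the indices within that file, yields its samples, and only then moves on to the next file. Shuffling is therefore only two-level: samples from different files are never interleaved.

```python
import numpy as np


def iterate_file_by_file(input_files, label_files, rng=None):
    """Yield (x, y) pairs one file at a time, shuffling files and samples."""
    rng = np.random.default_rng() if rng is None else rng
    file_order = rng.permutation(len(input_files))  # shuffle the file order
    for f in file_order:
        # Read the whole file into RAM once (no mmap), so per-sample access is cheap.
        inputs = np.load(input_files[f])
        labels = np.load(label_files[f])
        for i in rng.permutation(len(labels)):  # shuffle within the file
            yield inputs[i], labels[i]
        # Drop the references so this file can be freed before the next one loads.
        del inputs, labels
```

To use it with a DataLoader, you could wrap it in a `torch.utils.data.IterableDataset` whose `__iter__` returns this generator (converting samples with `torch.from_numpy`). With `num_workers > 0`, each worker would iterate over all files unless you split `input_files` per worker, e.g. based on `torch.utils.data.get_worker_info()`.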

In case you are indeed running out of memory, you could try to use np.memmap, which doesn't read the complete array into memory but lets you index it directly from disk.
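Note that `np.load(..., mmap_mode='r')`, as used in `MyDataset`, already returns an `np.memmap`, so data is paged in from disk on demand. The small sketch below uses a temporary file (the file name and shapes are illustrative only). It also sorts the indices of a batch before reading; reading in file order is typically cheaper on disk than scattered single-row reads, although the actual gain depends on the storage.

```python
import os
import tempfile

import numpy as np

# Write a small example array to disk (a stand-in for one of the .npy files).
path = os.path.join(tempfile.mkdtemp(), "inputs.npy")
np.save(path, np.arange(1000 * 3 * 4 * 4, dtype=np.float32).reshape(1000, 3, 4, 4))

# mmap_mode="r" maps the file read-only instead of reading it all into RAM.
inputs = np.load(path, mmap_mode="r")
print(type(inputs).__name__, inputs.shape)  # memmap (1000, 3, 4, 4)

# Only the pages holding the requested rows are read from disk; sorting the
# batch indices keeps the reads in file order.
rng = np.random.default_rng(0)
batch_idx = np.sort(rng.choice(len(inputs), size=8, replace=False))
batch = np.array(inputs[batch_idx])  # fancy indexing copies into a regular ndarray
print(type(batch).__name__, batch.shape)  # ndarray (8, 3, 4, 4)
```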

@ptrblck, could you please share an example of what you mean? Do you mean inside __getitem__?

you could implement a logic, which opens a single numpy file, shuffles the indices, reads all samples, and loads the next file

How should I handle the index argument of __getitem__? If I have 10 files, each with 1000 images, should __len__ return 10 or 10000?
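For a map-style Dataset, `__len__` usually returns the total number of samples (10000 here), and the index passed to `__getitem__` is a global sample index that you map to a (file, row) pair, as ConcatDataset does. To keep a file-by-file access pattern with such a dataset, one option is a custom sampler that yields indices grouped by file. The sketch below is hypothetical (`file_grouped_indices` is not a PyTorch API); it assumes the per-file sample counts are known and that the files are concatenated in order.

```python
import random
from itertools import accumulate


def file_grouped_indices(file_lengths, seed=None):
    """Yield global indices: files in random order, samples shuffled within each file."""
    rng = random.Random(seed)
    starts = [0] + list(accumulate(file_lengths))[:-1]  # global offset of each file
    file_order = list(range(len(file_lengths)))
    rng.shuffle(file_order)
    for f in file_order:
        local = list(range(file_lengths[f]))
        rng.shuffle(local)
        for i in local:
            yield starts[f] + i


# 10 files with 1000 samples each -> the dataset's __len__ would be 10000
indices = list(file_grouped_indices([1000] * 10, seed=0))
print(len(indices))  # 10000
```

DataLoader accepts an iterable of indices via its `sampler` argument, but a plain generator is exhausted after one epoch, so you would typically wrap this in a `torch.utils.data.Sampler` subclass whose `__iter__` calls it, giving a fresh order each epoch. Consecutive indices then hit the same file, which fits well with a dataset that caches the most recently loaded file.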