Map-style data loader for multiple directories/files

I’m trying to create a Dataset for a large dataset, probably 100-200GB in total. It is currently partitioned into a few directories, each containing ~50 .pt files. Each file has approximately 4,000,000 rows, and each row is a single sample. The number of features is fixed within a file but differs from file to file.
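With this many files, a map-style Dataset has to translate a flat global index into a file and a row within that file. A minimal sketch, assuming the per-file row counts are known up front; `cumulative_sizes` and `locate` are hypothetical helper names. It mirrors the "bisect on running totals" approach that `torch.utils.data.ConcatDataset` uses internally.

```python
import bisect
from itertools import accumulate


def cumulative_sizes(row_counts):
    """Running totals of rows per file, e.g. [4, 3, 5] -> [4, 7, 12]."""
    return list(accumulate(row_counts))


def locate(index, cum_sizes):
    """Map a global sample index to (file_idx, row_idx) within that file."""
    if not 0 <= index < cum_sizes[-1]:
        raise IndexError(index)
    # first file whose running total is strictly greater than index
    file_idx = bisect.bisect_right(cum_sizes, index)
    start = cum_sizes[file_idx - 1] if file_idx > 0 else 0
    return file_idx, index - start
```

For example, with files of 4, 3 and 5 rows, global index 4 is the first row of the second file.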

Naturally, I can’t (and don’t want to) hold more than a small part of this in memory at any one time. I do, however, want to iterate through one file at a time: every mini-batch then comes from a single file and has the same number of features, which saves me from bucketing batches when training an RNN.
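The "one file at a time" requirement can be expressed as a batch sampler rather than as state inside the Dataset. A minimal sketch, assuming per-file row counts are known and that the Dataset maps global indices to (file, row) and caches one file; `PerFileBatchSampler` is a hypothetical name. `DataLoader` accepts any iterable of index lists as `batch_sampler`, so this class does not subclass `torch.utils.data.Sampler`.

```python
import random


class PerFileBatchSampler:
    """Yield lists of global indices; every batch comes from a single file."""

    def __init__(self, row_counts, batch_size, shuffle=True, seed=0):
        self.row_counts = list(row_counts)
        self.batch_size = batch_size
        self.shuffle = shuffle
        self.seed = seed
        self.epoch = 0
        # global index of the first row in each file
        self.offsets = [0]
        for n in self.row_counts[:-1]:
            self.offsets.append(self.offsets[-1] + n)

    def set_epoch(self, epoch):
        # change the shuffle order between epochs
        self.epoch = epoch

    def __iter__(self):
        rng = random.Random(self.seed + self.epoch)
        file_order = list(range(len(self.row_counts)))
        if self.shuffle:
            rng.shuffle(file_order)  # shuffle which file comes next
        for f in file_order:
            rows = list(range(self.row_counts[f]))
            if self.shuffle:
                rng.shuffle(rows)  # shuffle rows only within the cached file
            start = self.offsets[f]
            for i in range(0, len(rows), self.batch_size):
                yield [start + r for r in rows[i:i + self.batch_size]]

    def __len__(self):
        # ceil(rows / batch_size) batches per file; no batch crosses a file
        return sum(-(-n // self.batch_size) for n in self.row_counts)
```

Usage would look like `DataLoader(dataset, batch_sampler=PerFileBatchSampler(row_counts, 256))`, leaving `batch_size`, `shuffle` and `sampler` unset because they are mutually exclusive with `batch_sampler`. Note that with `num_workers > 0` batches are handed to workers in turn, so each worker process loads its own copy of the current file.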

Has anyone encountered this sort of problem before? My idea is to keep an internal index that determines which .pt file __getitem__ reads from and to cache that file in memory. When the end of the file is reached, the index would advance and the next file would be loaded into memory. However, I don’t know how this will interact with a DistributedSampler passed to the DataLoader constructor to parallelise training.
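On the DistributedSampler question: `DistributedSampler` splits sample indices across ranks by interleaving them. With `shuffle=True` each rank jumps randomly between files, so a one-file cache would keep reloading; even with `shuffle=False` every rank ends up loading every file. One alternative, sketched here under the assumption that per-file row counts are known, is to shard whole files across ranks. `DistributedFileBatchSampler` is a hypothetical name; `rank` and `world_size` would typically come from `torch.distributed.get_rank()` and `torch.distributed.get_world_size()`. Because files differ in size, every rank is capped at the batch count of the least-loaded rank, since uneven step counts can make DDP hang.

```python
import random


class DistributedFileBatchSampler:
    """Shard whole files across ranks and batch within each file."""

    def __init__(self, row_counts, batch_size, rank, world_size, seed=0):
        self.row_counts = list(row_counts)
        self.batch_size = batch_size
        self.rank = rank
        self.world_size = world_size
        self.seed = seed
        self.epoch = 0
        self.offsets = [0]  # global index of the first row in each file
        for n in self.row_counts[:-1]:
            self.offsets.append(self.offsets[-1] + n)

    def set_epoch(self, epoch):
        self.epoch = epoch

    def _shards(self):
        # identical on every rank: same seed and epoch give the same order
        order = list(range(len(self.row_counts)))
        random.Random(self.seed + self.epoch).shuffle(order)
        return [order[r::self.world_size] for r in range(self.world_size)]

    def _num_batches(self, files):
        return sum(-(-self.row_counts[f] // self.batch_size) for f in files)

    def __len__(self):
        # every rank yields the same number of batches
        return min(self._num_batches(s) for s in self._shards())

    def __iter__(self):
        limit = len(self)
        # row order only needs to be reproducible, not shared across ranks
        rng = random.Random(f"{self.seed}-{self.epoch}-{self.rank}")
        emitted = 0
        for f in self._shards()[self.rank]:
            rows = list(range(self.row_counts[f]))
            rng.shuffle(rows)
            for i in range(0, len(rows), self.batch_size):
                if emitted == limit:
                    return
                yield [self.offsets[f] + r for r in rows[i:i + self.batch_size]]
                emitted += 1
```

As with `DistributedSampler`, `set_epoch(epoch)` should be called at the start of each epoch so the file order changes. This assumes there are at least as many files as ranks.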

Below is an untested prototype Dataset:

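import os

import torch
from torch.utils.data import Dataset


class FlowData(Dataset):
    def __init__(self, roots):
        self.roots = roots
        self.files = self._files()
        self.N = self._len()
        self._parsed = []  # list of already parsed `.pt` files
        self._global_index = 0  # index of the current `.pt` file
        self._local_index = 0  # index of the next row in that file
        self.x, self.y = self._load_internal(self._global_index)

    def _load_internal(self, index):
        # each `.pt` file is assumed to hold an (x, y) pair of tensors
        x, y = torch.load(self.files[index])
        self._parsed.append(self.files[index])
        return x, y

    def __getitem__(self, index):
        # NOTE: `index` is ignored; rows are served in order from the cached
        # file, so any sampler (including DistributedSampler) has no effect,
        # and each DataLoader worker keeps its own copy of this state
        if self._local_index >= self.x.shape[0]:
            # current file exhausted: load the next one (wrapping around)
            self._global_index = (self._global_index + 1) % len(self.files)
            self._local_index = 0
            self.x, self.y = self._load_internal(self._global_index)
        x, y = self.x[self._local_index], self.y[self._local_index]
        self._local_index += 1
        return x, y

    def _files(self):
        files = []
        for root in self.roots:
            for name in os.listdir(root):
                if name.endswith(".pt"):
                    files.append(os.path.join(root, name))
        return sorted(files)

    def _len(self):
        # loads every file once just to count rows, which is slow at 100-200GB
        N = 0
        for file in self.files:
            data = torch.load(file)
            N += data[0].shape[0]
        return N

    def __len__(self):
        return self.N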
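A variant of the prototype that honours `index` (so samplers keep working) while still caching only one file could look like the sketch below. It is a minimal sketch under stated assumptions: `CachedFileRows` is a hypothetical name, `load_fn` stands in for `torch.load` and must return an indexable `(x, y)` pair, and `row_counts` is assumed to be precomputed once (for example saved alongside the data) so that every file is not loaded just to compute the length. It does not subclass `Dataset` only to stay dependency-free; `DataLoader` works with any object implementing `__len__` and `__getitem__`.

```python
import bisect
from itertools import accumulate


class CachedFileRows:
    """Map-style dataset over many files that keeps one file in memory."""

    def __init__(self, files, row_counts, load_fn):
        if len(files) != len(row_counts):
            raise ValueError("need one row count per file")
        self.files = list(files)
        self.cum_sizes = list(accumulate(row_counts))  # running row totals
        self.load_fn = load_fn  # e.g. torch.load
        self._cached_idx = None  # which file is currently in memory
        self._x = self._y = None
        self.loads = 0  # number of file loads, useful to check cache hits

    def __len__(self):
        return self.cum_sizes[-1] if self.cum_sizes else 0

    def _locate(self, index):
        if index < 0:
            index += len(self)
        if not 0 <= index < len(self):
            raise IndexError(index)
        f = bisect.bisect_right(self.cum_sizes, index)
        start = self.cum_sizes[f - 1] if f > 0 else 0
        return f, index - start

    def __getitem__(self, index):
        f, row = self._locate(index)
        if f != self._cached_idx:
            # replace the cached file; the previous one can be freed
            self._x, self._y = self.load_fn(self.files[f])
            self._cached_idx = f
            self.loads += 1
        return self._x[row], self._y[row]
```

The cache only pays off when consecutive requests hit the same file, and because feature counts differ between files, it should be paired with a batch sampler that never lets a batch cross a file boundary; otherwise the default collate function will fail to stack rows of different widths.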