I’m trying to create a Dataset for a large dataset. The data is roughly 100-200 GB, currently partitioned into a few directories; each directory holds ~50 .pt files, each with approximately 4,000,000 rows and a variable number of features (each .pt file has a different number of features), and each row of these files is a single sample.
Naturally I can’t and don’t want to hold even a fraction of this in memory at any one time, but I do want to iterate through one file at a time, because this avoids any effort of bucketing batches when training an RNN: each mini-batch will come from the same file and therefore have the same number of features.
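To make that access pattern concrete, here is roughly what I mean by "iterate through one file at a time", written as an IterableDataset (just an untested sketch; the class name is mine, and it assumes each .pt file stores an (x, y) tuple as in the prototype further down):

import os
import torch
from torch.utils.data import IterableDataset


class PerFileStream(IterableDataset):
    """Yield rows one file at a time, so consecutive samples share a feature width."""

    def __init__(self, roots):
        super().__init__()
        self.files = sorted(
            os.path.join(root, name) for root in roots for name in os.listdir(root)
        )

    def __iter__(self):
        for path in self.files:
            x, y = torch.load(path)      # hold exactly one file in memory
            for i in range(x.shape[0]):  # exhaust its rows before moving to the next file
                yield x[i], y[i]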
Has anyone encountered this sort of problem before? I thought about storing some sort of internal index that determines which file __getitem__ returns samples from, caching that file internally, and then, when reaching its maximum length, initialising a new index and loading the newly indexed file into memory. However, I don’t know what impact this will have when using a DistributedSampler in a DataLoader constructor to parallelise training.
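For reference, this is the usage pattern I have in mind. As far as I understand it, DistributedSampler gives each process its own (shuffled) shard of the global indices, so __getitem__ would not be called with consecutive indices from a single file (sketch only; dataset and num_epochs are placeholders):

from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler

sampler = DistributedSampler(dataset, shuffle=True)           # shards the global indices per rank
loader = DataLoader(dataset, batch_size=64, sampler=sampler)

for epoch in range(num_epochs):
    sampler.set_epoch(epoch)   # re-shuffles the sharded indices each epoch
    for x, y in loader:
        ...                    # training step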
I have outlined an untested prototype:
import os
import torch
from torch.utils.data import Dataset


class FlowData(Dataset):
    def __init__(self, roots):
        super(FlowData, self).__init__()
        self.roots = roots
        self.files = self._files()
        self.N = self._len()
        self._parsed = []          # list of already parsed `.pt` files
        self._global_index = 0     # index of the currently cached `.pt` file
        self._local_index = 0      # local index of the row in that `.pt` file
        self.x, self.y = self._load_internal(self._global_index)

    def _load_internal(self, index):
        x, y = torch.load(self.files[index])
        return x, y

    def __getitem__(self, index):
        if torch.is_tensor(index):
            index = index.tolist()
        try:
            x, y = self.x[self._local_index], self.y[self._local_index]
            self._local_index += 1
            return x, y
        except IndexError:
            # cached file exhausted: remember it, cache the next one, and start over
            self._parsed.append(self.files[self._global_index])
            self._global_index += 1
            self._local_index = 0
            self.x, self.y = self._load_internal(self._global_index)
            return self.__getitem__(index)

    def _files(self):
        files = []
        for root in self.roots:
            for name in os.listdir(root):
                files.append(os.path.join(root, name))
        return sorted(files)

    def _len(self):
        # full pass over every file just to count the total number of rows
        N = 0
        for path in self.files:
            x, _ = torch.load(path)
            N += x.shape[0]
        return N

    def __len__(self):
        return self.N
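For completeness, this is roughly how I would construct it (also untested; the directory names and batch size are placeholders):

from torch.utils.data import DataLoader

dataset = FlowData(roots=["./part_0", "./part_1"])   # hypothetical partition directories
loader = DataLoader(dataset, batch_size=128, shuffle=False)

for x, y in loader:
    ...  # rows come out of the currently cached file in order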