Hey,
since I can generate my training data programmatically, I effectively have an unlimited dataset and create the samples on the fly with a torch `Dataset` and `DataLoader`.
But since generating samples is moderately expensive, I want to reuse the generated data for a limited lifetime, say 10 epochs, before overfitting effects would set in.
The thing is: with a single-threaded DataLoader (i.e. `num_workers=0`) everything works fine. But as soon as I use multiple workers, their writes to the dataset don't seem to stick, so the check for existing samples always fails and new samples are generated every time.
The reason I like doing this with torch's `Dataset`/`DataLoader` is that I don't have to deal with parallel programming myself and all of my CPU cores are nicely utilised for generating the data.
Summarised, I basically want to do the following:
- Init the `Dataset` class and give it a lifetime of 10 epochs.
- For each epoch, for each batch: draw a sample from the dataset. If the sample is `None`, generate a new one, save it as `self.samples[index] = function_which_generates_samples()` and return it; otherwise return the existing sample.
- After each epoch: `dataset.time_til_death -= 1`; if `dataset.time_til_death == 0`: `self.samples = [None] * pseudo_length`.
If it's helpful, here is a dummy example (works for `num_workers=0` but not for more):
```python
import torch
from torch.utils.data import Dataset, DataLoader


class DummyDataset(Dataset):
    def __init__(self, lifetime=5, ds_size=8):
        super().__init__()
        self.lifetime = lifetime
        self.time_til_death = lifetime
        self.ds_size = ds_size
        self.samples = None
        self.gt = None
        self._drop_dataset()

    def step(self):
        self.time_til_death -= 1
        if self.time_til_death <= 0:
            self._drop_dataset()
            self.time_til_death = self.lifetime

    def _drop_dataset(self):
        self.samples = [None] * len(self)
        self.gt = [None] * len(self)
        print("Dropped.")

    def _new_sample(self):
        return torch.rand(1), torch.tensor([1])

    def __len__(self):
        return self.ds_size

    def __getitem__(self, index):
        if self.samples[index] is not None:
            print("Old sample.")
            sample = self.samples[index]
            gt = self.gt[index]
        else:
            print("New sample.")
            sample, gt = self._new_sample()
            self.samples[index] = sample
            self.gt[index] = gt
        return sample, gt


if __name__ == '__main__':
    ds = DummyDataset(lifetime=2, ds_size=2)
    dl = DataLoader(ds, batch_size=2, shuffle=True, num_workers=4)
    for e in range(3):
        print(f"Epoch {e}")
        for i, (sample, gt) in enumerate(dl):
            print(f'Batch: {i} Sample: {sample} GT {gt}')
        ds.step()
```