A couple of times I’ve been training on large datasets where one epoch takes several hours, and a 1-in-1,000,000 bug came up at some point during the epoch. Luckily I validate every so many train steps, and at that point I save a checkpoint.
BUT, the problem is that I want to resume from exactly where I left off. So in my train loop I can say:
if train_step <= n:
    continue
But then my DataLoader is still doing all the I/O and preprocessing for all steps <= n. I then tried this in my dataset:
class Dataset(torch.utils.data.Dataset):
    def __init__(self):
        ...
        self.skip = True

    def __getitem__(self, index):
        if self.skip:
            return torch.zeros(1)
        # normal I/O and preprocessing
Then in my train loop:
if train_step <= n:
    if train_step == n:
        dataset.skip = False
    continue
But this seems to have no effect. (I suspect that’s because with num_workers > 0 each DataLoader worker gets its own copy of the dataset, so flipping the flag on the main-process copy never reaches the workers.)
Is there a way to set a certain step index on a DataLoader prior to iteration?
I suppose so. There are many things I could do. I was just wondering if there was a built-in way to do this easily, as I imagine it’s not an uncommon need.
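For instance (just a sketch of one such option, assuming a sequential, non-shuffled loader; batch_size here stands in for my batch size and n is the step count from above):

from torch.utils.data import DataLoader, Subset

# rebuild the loader around a Subset so the samples consumed by the
# first n steps are never fetched at all
resume_dataset = Subset(dataset, range(n * batch_size, len(dataset)))
loader = DataLoader(resume_dataset, batch_size=batch_size, shuffle=False)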
@doem97 sorry, but no, I didn’t manage to find a good workaround for this. I just ended up dodging the problem at a higher level, outside the scope of this thread.
As a workaround for people requesting it, I think you can do it via a custom sampler:
from torch.utils.data import Sampler

class CustomSampler(Sampler[int]):
    def __init__(self, data_source):
        self.data_source = data_source
        self.skip_ctr = 0  # how many leading indices to skip

    def skip(self, n):
        self.skip_ctr = n

    def __iter__(self):
        # yield only the indices after the skip point, in order
        for i in range(self.skip_ctr, len(self.data_source)):
            yield i

    def __len__(self):
        return len(self.data_source) - self.skip_ctr
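Usage would look something like this (just a sketch; batch_size and n stand in for your own values, and note the skip is counted in samples, so a step count has to be multiplied by the batch size):

from torch.utils.data import DataLoader

sampler = CustomSampler(dataset)
sampler.skip(n * batch_size)  # skip the samples consumed by the first n steps
loader = DataLoader(dataset, batch_size=batch_size, sampler=sampler)

for batch in loader:
    ...  # resumes as if the first n steps had already run, without their I/O

One caveat: this sampler yields indices in order, so it replaces the default shuffling; if you shuffle each epoch, you’d need to skip within the (seeded) shuffled order instead.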