A couple of times I've been training on large datasets where one epoch takes several hours, and a 1-in-1,000,000 bug came up at some point during the epoch. Luckily I validate every so many train steps, and at that point I save a checkpoint.
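For reference, the checkpointing looks roughly like this (the names here are just illustrative):

import torch

def save_checkpoint(model, optimizer, train_step, path):
    # store the step index alongside the weights so a resume
    # knows exactly where in the epoch it left off
    torch.save(
        {
            "model": model.state_dict(),
            "optimizer": optimizer.state_dict(),
            "train_step": train_step,
        },
        path,
    )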
BUT, the problem is that I want to get back to exactly where I was previously. So in my train loop I can say:
if train_step <= n:
    continue
But then my dataloader is still doing all the I/O and preprocessing for all steps <= n. I then tried this in my dataset:
class Dataset(torch.utils.data.Dataset):
    def __init__(self):
        ...
        self.skip = True

    def __getitem__(self, index):
        if self.skip:
            # return a dummy tensor without touching the disk
            return torch.zeros(1)
        # normal I/O and preprocessing
Then in my train loop:
if train_step <= n:
    if train_step == n:
        dataset.skip = False  # turn real loading back on for subsequent steps
    continue
But this seems to have no effect. I suspect it's because I'm running with num_workers > 0, so each worker process gets its own copy of the dataset when the iterator is created and never sees the flag flip in the main process, but I'm not sure.
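Here's a minimal repro of the flag mutation never reaching the workers (FlagDataset is just an illustrative name; this assumes num_workers > 0):

import torch
from torch.utils.data import Dataset, DataLoader

class FlagDataset(Dataset):
    def __init__(self):
        self.skip = True

    def __len__(self):
        return 8

    def __getitem__(self, index):
        # reports the flag value as seen by whichever process runs this call
        return torch.tensor([float(self.skip)])

if __name__ == "__main__":
    dataset = FlagDataset()
    loader = DataLoader(dataset, batch_size=1, num_workers=2)
    for step, batch in enumerate(loader):
        if step == 3:
            dataset.skip = False  # only mutates the main-process copy
        print(step, batch.item())  # prints 1.0 for every step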
Is there a way to set a certain step index on a DataLoader prior to iteration?
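Ideally I'd want something like the sketch below, where a custom sampler simply starts iterating at a given offset, so the skipped samples are never loaded at all (SkipToSampler is a name I made up, and this assumes a map-style dataset with sequential, non-shuffled sampling):

import torch
from torch.utils.data import DataLoader, Sampler

class SkipToSampler(Sampler):
    # yields indices [start_index, dataset_len), so the first
    # start_index samples are never loaded or preprocessed
    def __init__(self, dataset_len, start_index=0):
        self.dataset_len = dataset_len
        self.start_index = start_index

    def __iter__(self):
        return iter(range(self.start_index, self.dataset_len))

    def __len__(self):
        return self.dataset_len - self.start_index

# resume at train step n by skipping the first n batches' worth of samples:
# loader = DataLoader(dataset, batch_size=batch_size,
#                     sampler=SkipToSampler(len(dataset), start_index=n * batch_size))

With shuffle=True I'd presumably also need to fix the seed so the permutation matches the original run, which is why I'm hoping there's something built in.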