A couple of times I’ve been training on large datasets where one epoch takes several hours, and a one-in-a-million bug came up at some point during the epoch. Luckily I validate every so many train steps, and at that point I save a checkpoint.
BUT, the problem is that I want to get back to exactly where I was previously. So in my train loop I can say
```python
if train_step <= n:
    continue
```
But then my dataloader still does all the I/O and preprocessing for every step <= n. So I tried this in my dataset:
```python
class Dataset(torch.utils.data.Dataset):
    def __init__(self):
        ...
        self.skip = True

    def __getitem__(self, index):
        if self.skip:
            return torch.zeros(1)
        # normal I/O and preprocessing
        ...
```
Then in my train loop:
```python
if train_step <= n:
    if train_step == n:
        dataset.skip = False
    continue
```
But this seems to have no effect.
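I suspect the reason is that DataLoader workers are separate processes, so each worker holds its own copy of the dataset and flipping `dataset.skip` in the main process after iteration has started never reaches them. A quick torch-free sketch of that copy semantics (`ToyDataset` and the names below are mine, and I'm using the Linux `fork` start method for simplicity):

```python
import multiprocessing as mp

class ToyDataset:
    """Stand-in for a torch Dataset; just holds the skip flag."""
    def __init__(self):
        self.skip = True

def read_flag(ds, q):
    # Runs in the child process, analogous to a DataLoader worker.
    q.put(ds.skip)

ctx = mp.get_context("fork")  # Linux-only; DataLoader workers are separate processes too
ds = ToyDataset()
q = ctx.Queue()
worker = ctx.Process(target=read_flag, args=(ds, q))
worker.start()       # the child receives its own copy of ds at fork time
ds.skip = False      # flipping the flag in the parent afterwards...
flag_seen_by_worker = q.get()
worker.join()
print(flag_seen_by_worker)  # ...never reaches the child's copy: prints True
```

If that's really what's happening, then mutating the dataset object after the loader's workers have started can't work with `num_workers > 0`.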
Is there a way to set a certain step index on a DataLoader prior to iteration?
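For what it's worth, the closest thing I can sketch myself is a custom sampler that simply never yields the first `n * batch_size` indices, so `__getitem__` is never called for the skipped samples at all. This is a torch-free sketch (`SkipSampler` is my own name; in the real setup it would subclass `torch.utils.data.Sampler` and be passed as `sampler=` to the DataLoader):

```python
class SkipSampler:
    """Yields dataset indices starting at start_step * batch_size, so the
    skipped samples are never requested from the dataset (no I/O, no
    preprocessing for them)."""
    def __init__(self, dataset_len, batch_size, start_step=0):
        self.dataset_len = dataset_len
        self.batch_size = batch_size
        self.start_step = start_step  # number of train steps already completed

    def __iter__(self):
        return iter(range(self.start_step * self.batch_size, self.dataset_len))

    def __len__(self):
        return max(0, self.dataset_len - self.start_step * self.batch_size)

# Resuming at step 3 with batch size 2 over a 10-sample dataset:
sampler = SkipSampler(dataset_len=10, batch_size=2, start_step=3)
print(list(sampler))  # [6, 7, 8, 9]
```

Note this only resumes correctly for sequential (unshuffled) sampling; with shuffling I'd presumably also need to restore the RNG state so the permutation matches the interrupted epoch.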