Is there any way to skip steps in a DataLoader?

A couple of times I’ve been training on large datasets where one epoch takes several hours, and a 1 in 1,000,000 bug came up at some point during the epoch. Luckily I validate every so many train steps, and at that point I save a checkpoint.

BUT, the problem is that I want to get back to exactly where I was previously. So in my train loop I can say

if train_step <= n:
    continue

But then my dataloader is still doing all the I/O and preprocessing for all steps <= n. I then tried this in my dataset

class Dataset(torch.utils.data.Dataset):
    def __init__(self):
        ....
        self.skip = True

    def __getitem__(self, index):
        if self.skip:
            return torch.zeros(1)
        # normal I/O and preprocessing

Then in my train loop:

if train_step <= n:
    if train_step == n:
        dataset.skip = False
    continue

But this seems to have no effect.

Is there a way to set a certain step index on a DataLoader prior to iteration?

Maybe try using a custom Sampler?

I suppose so. There are many things I could do. Was just wondering if there was a built-in way to do this easily as I imagine it’s not an uncommon need.

Yeah, I would say no built-in way for now. But, we are working on a new design of DataLoader, which IMO will provide this functionality.

Hi @Alexander_Soare, how did you solve this in the end? Any workaround would be helpful!
Thanks.

@doem97 sorry, but no, I didn’t manage to find a good workaround for this. I just ended up dodging the problem at a higher level, outside the scope of this thread.

Is there any newer solution for this? @ejguan

The DataLoader2 feature for checkpointing the data pipeline is still WIP.

As a workaround for people requesting it, I think you can do it via a custom sampler:

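from torch.utils.data import Sampler

class CustomSampler(Sampler[int]):
    def __init__(self, data_source):
        self.data_source = data_source
        self.skip_ctr = 0

    def skip(self, n):
        # n is the number of samples (not batches) to skip
        self.skip_ctr = n

    def __iter__(self):
        # yield indices sequentially, starting after the skipped samples
        n = len(self.data_source)
        for i in range(self.skip_ctr, n):
            yield i

    def __len__(self):
        # number of samples remaining after the skip
        return len(self.data_source) - self.skip_ctr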
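
For anyone wiring this into a training loop, here is a minimal sketch of how such a sampler might be used to resume mid-epoch. The toy `TensorDataset`, the name `SkippableSampler`, and the `steps_done` value are made up for illustration; in practice `steps_done` would come from your checkpoint. It also assumes sequential (unshuffled) order and that `batch_size` is the same as in the interrupted run.

```python
import torch
from torch.utils.data import DataLoader, Sampler, TensorDataset


class SkippableSampler(Sampler[int]):
    """Sequential sampler that can start part-way through the dataset."""

    def __init__(self, data_source):
        self.data_source = data_source
        self.skip_ctr = 0

    def skip(self, n):
        # n counts samples, not batches
        self.skip_ctr = n

    def __iter__(self):
        # indices before skip_ctr are never requested from the dataset,
        # so no I/O or preprocessing happens for them
        return iter(range(self.skip_ctr, len(self.data_source)))

    def __len__(self):
        return len(self.data_source) - self.skip_ctr


# Toy dataset standing in for the real one (hypothetical).
dataset = TensorDataset(torch.arange(10))
batch_size = 2
steps_done = 3  # hypothetical: number of train steps completed before the crash

sampler = SkippableSampler(dataset)
sampler.skip(steps_done * batch_size)  # convert steps to samples
loader = DataLoader(dataset, batch_size=batch_size, sampler=sampler)

# Only the batches after the checkpointed step are produced.
resumed_batches = [batch[0].tolist() for batch in loader]
print(resumed_batches)
```

Remember to call `sampler.skip(0)` before the next epoch, otherwise every later epoch will also start from the resume point. If the original run shuffled the data, skipping only reproduces the same position when the index order is itself reproducible, for example a permutation generated from a seed saved alongside the checkpoint.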