Resume iterating dataloader from checkpoint batch_idx

Hi,

I was wondering whether it is possible to resume iterating through a dataloader from a checkpoint.

For example:

    dataloaders_dict = {
        phase: torch.utils.data.DataLoader(
            datasets_dict[phase],
            batch_size=args.batch_size,
            num_workers=args.num_workers,
            shuffle=False,  # make sure shuffling is off in case you restart
        )
        for phase in ['train']
    }

    if os.path.isdir(args.batchidx_checkpoint):
        checkpoint = torch.load(args.batchidx_checkpoint + 'batches.pt')
        batch_idx = checkpoint['batch_idx'] + 1
    else:
        os.makedirs(args.batchidx_checkpoint)
        batch_idx = 0

    for batch_idx, (inputs, labels) in enumerate(dataloaders_dict['train']):
        # print(batch_idx)
        # do some feature extraction and save feature arrays in an hdf5 file

I have seen some examples mentioning a random sampler, but was wondering if it is really necessary…
I also have about 5 million images to go through, so it is a large dataset to iterate over. My end goal is to save features for each image using a pretrained network, so no training is really required. I just want to resume feature extraction from a checkpointed batch id…
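For context, the naive workaround I can think of is to keep the enumerate loop and just skip batches below the checkpointed index, but that still loads and decodes every image up to the start point, which is slow with 5 million images:

    start_idx = batch_idx  # loaded from the checkpoint above
    for batch_idx, (inputs, labels) in enumerate(dataloaders_dict['train']):
        if batch_idx < start_idx:
            continue  # skipped batches are still loaded, so this wastes time
        # feature extraction as before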

Cheers,


Hi,

I don’t think you can do that with the vanilla modules.
But it should be fairly easy to write your own sampler based on SequentialSampler(), which is about 15 lines of code: add an argument for the start point and have __iter__ return iter(range(self.start_idx, len(self.data_source))).


Hi,

Thanks for the reply. I created a new SequentialSampler class in my script:

    class SequentialSampler2():
        r"""Samples elements sequentially, always in the same order,
        starting from a given index.

        Arguments:
            data_source (Dataset): dataset to sample from
            start_idx (int): index to start sampling from
        """

        def __init__(self, data_source, start_idx=0):
            self.data_source = data_source
            self.start_idx = start_idx

        def __iter__(self):
            return iter(range(self.start_idx, len(self.data_source)))

        def __len__(self):
            return len(self.data_source)

and tried to pass it as the sampler in my dataloader, like this:

    datasets_dict = {phase: Patches(root=args.root_path + phase, phase=phase) for phase in ['train']}
    sampler = {phase: SequentialSampler2(datasets_dict[phase]) for phase in ['train']}
    dataloaders_dict = {
        phase: torch.utils.data.DataLoader(
            datasets_dict[phase],
            batch_size=args.batch_size,
            sampler=sampler,
            num_workers=args.num_workers,
            shuffle=False,  # make sure shuffling is off in case you restart
        )
        for phase in ['train']
    }

    resnet50 = models.resnet50(pretrained=True)
    for param in resnet50.parameters():
        param.requires_grad = False

    # check batch checkpoint
    if os.path.isdir(args.batchidx_checkpoint):
        checkpoint = torch.load(args.batchidx_checkpoint + 'batches.pt')
        batch_idx = checkpoint['batch_idx'] + 1
    else:
        os.makedirs(args.batchidx_checkpoint)
        batch_idx = 0

    # do some other stuff

    for batch_idx, (inputs40x, inputs20x, inputs5x, paths40x, paths20x, paths5x, labels) in enumerate(dataloaders_dict['train']):
        inputs40x = inputs40x.to(device)
        inputs20x = inputs20x.to(device)
        inputs5x = inputs5x.to(device)
        labels = labels.to(device)

However, I appear to get an error:

    ValueError: sampler should be an instance of torch.utils.data.Sampler, but got sampler={'train': <__main__.SequentialSampler2 object at 0x128b2d748>}

Perhaps I’ve misunderstood how to feed a sampler into the dataloader. I’m also not sure whether this is the right approach for restarting from a certain batch_idx in the for loop.

It’s just that your sampler should be a subclass of the original torch.utils.data.Sampler, as the error states.
You can fix that by declaring class SequentialSampler2(torch.utils.data.Sampler): so it inherits from it, and by calling super().__init__() on the first line of your __init__ to initialize the parent class. Also note that the error shows you are passing the whole dict to the dataloader; pass sampler=sampler[phase] instead.
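Putting those fixes together, a minimal sketch of the corrected sampler might look like this (start_idx is the new argument; note that on the PyTorch version used in this thread, Sampler.__init__ expects data_source, which a later post runs into):

    import torch

    class SequentialSampler2(torch.utils.data.Sampler):
        """Sequential sampler that can resume from a given start index."""

        def __init__(self, data_source, start_idx=0):
            # older PyTorch requires data_source here; newer versions accept
            # no arguments (the parameter is deprecated but still tolerated)
            super().__init__(data_source)
            self.data_source = data_source
            self.start_idx = start_idx

        def __iter__(self):
            # resume sequential iteration from the checkpointed index
            return iter(range(self.start_idx, len(self.data_source)))

        def __len__(self):
            # number of samples this sampler will actually yield
            return len(self.data_source) - self.start_idx

Then build it per phase and pass the instance rather than the dict, e.g. sampler=SequentialSampler2(datasets_dict[phase], start_idx=batch_idx * args.batch_size) (converting the checkpointed batch index to a sample index, assuming a fixed batch size).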


Ahh, ok, thank you! 🙂

Why does PyTorch not support this out of the box? Doesn’t that make reproducible ML much harder than it needs to be?

Hi,

I’m sure this could be added, but there was no explicit feature request for it.
Could you clarify why this is particularly important for reproducible ML? Do you often want to start your dataset from the middle?

This is crucial when saving checkpoints, because we have a cluster that only allows 1-day jobs, so I have to restart another day from the exact place I left off.

Reproducibility in deep learning is a very big issue, and this doesn’t make it better.

I’d make a feature request.


I see the point about resuming training.
But do you actually save all the random state when you checkpoint? Otherwise the sampler won’t continue sampling exactly the same sequence.
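For what it’s worth, a rough sketch of capturing and restoring the usual RNG states alongside a checkpoint (the helper names and path are made up; note this does not cover DataLoader worker processes, which seed their own RNGs):

    import random
    import numpy as np
    import torch

    def save_rng_checkpoint(path, batch_idx):
        # bundle the iteration counter with every RNG state we rely on
        torch.save({
            'batch_idx': batch_idx,
            'torch_rng': torch.get_rng_state(),
            'cuda_rng': torch.cuda.get_rng_state_all(),  # one entry per GPU
            'numpy_rng': np.random.get_state(),
            'python_rng': random.getstate(),
        }, path)

    def load_rng_checkpoint(path):
        # newer PyTorch may need torch.load(path, weights_only=False) here
        ckpt = torch.load(path)
        torch.set_rng_state(ckpt['torch_rng'])
        torch.cuda.set_rng_state_all(ckpt['cuda_rng'])
        np.random.set_state(ckpt['numpy_rng'])
        random.setstate(ckpt['python_rng'])
        return ckpt['batch_idx']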

Hi all, I recently ran into this problem as well; my sampler is a WeightedRandomSampler. I tried to add @albanD ’s solution but was unable to get it working with my custom sampler.

    from torch._six import int_classes as _int_classes  # needed on the older PyTorch this targets

    class WeightedRandomSampler2(torch.utils.data.Sampler):
        r"""Samples elements from ``[0,..,len(weights)-1]`` with given probabilities (weights).

        Args:
            weights (sequence): a sequence of weights, not necessarily summing up to one
            num_samples (int): number of samples to draw
            replacement (bool): if ``True``, samples are drawn with replacement.
                If not, they are drawn without replacement, which means that when a
                sample index is drawn for a row, it cannot be drawn again for that row.

        Example:
            >>> list(WeightedRandomSampler([0.1, 0.9, 0.4, 0.7, 3.0, 0.6], 5, replacement=True))
            [0, 0, 0, 1, 0]
            >>> list(WeightedRandomSampler([0.9, 0.4, 0.05, 0.2, 0.3, 0.1], 5, replacement=False))
            [0, 1, 4, 3, 2]
        """

        def __init__(self, weights, num_samples, replacement=True):
            super().__init__()

            if not isinstance(num_samples, _int_classes) or isinstance(num_samples, bool) or \
                    num_samples <= 0:
                raise ValueError("num_samples should be a positive integer "
                                 "value, but got num_samples={}".format(num_samples))
            if not isinstance(replacement, bool):
                raise ValueError("replacement should be a boolean value, but got "
                                 "replacement={}".format(replacement))
            self.weights = torch.as_tensor(weights, dtype=torch.double)
            self.num_samples = num_samples
            self.replacement = replacement

        def __iter__(self):
            return iter(self.start_idx, torch.multinomial(self.weights, self.num_samples, self.replacement).tolist())

        def __len__(self):
            return self.num_samples

That is, I added self.start_idx in __iter__(), and I received the following error when running the script:

    data_sampler = WeightedRandomSampler2(sample_weights, num_samples=len(sample_weights), replacement=True)
    ---------------------------------------------------------------------------
    TypeError                                 Traceback (most recent call last)
    <ipython-input-14-b420c064a884> in <module>()
    ----> 1 data_sampler = WeightedRandomSampler2(sample_weights, num_samples=len(sample_weights), replacement=True)

    <ipython-input-10-3c0a1378cd2d> in __init__(self, weights, num_samples, replacement)
         20     def __init__(self, weights, num_samples, replacement=True):
         21
    ---> 22       super().__init__()
         23
         24       if not isinstance(num_samples, _int_classes) or isinstance(num_samples, bool) or               num_samples <= 0:

    TypeError: __init__() missing 1 required positional argument: 'data_source'

Any idea?

Saving all the random state is essential for reproducing and debugging. For reproducing: the author offers the random seed and the code, and readers use the same settings to reproduce the results. If resuming happens during the author’s or the readers’ training process, the results differ. You might expect the algorithm to be stable across random states, but saving the state is precisely what guarantees reproducibility, and it spares us from worrying about sudden failures for reasons like “no space left on the hard disk”.

For debugging, it is surely useful when a bug occurs with certain data after many iterations, say within batch 100000. When I save a checkpoint at batch 99999 and then resume training several times to debug, the dataloader loads batch 0 for me! This is so inconvenient.

So I strongly suggest this feature. It should be an optional setting: turned on for debugging and reproducing, off for simple training. If more questions exist, I’d like to continue discussing.
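To make this concrete, a rough sketch of the kind of resumable loop I mean, reusing the SequentialSampler2 idea from earlier in this thread (the checkpoint path, interval, dataset, and batch_size are placeholders, and only the torch RNG state is saved for brevity):

    import torch

    CKPT = 'resume.pt'   # placeholder path
    SAVE_EVERY = 1000    # placeholder interval

    start_batch = 0
    try:
        ckpt = torch.load(CKPT)
        start_batch = ckpt['batch_idx'] + 1
        torch.set_rng_state(ckpt['torch_rng'])
    except FileNotFoundError:
        pass  # no checkpoint yet: start from the beginning

    # convert the batch index to a sample index (fixed batch size assumed)
    sampler = SequentialSampler2(dataset, start_idx=start_batch * batch_size)
    loader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, sampler=sampler)

    for i, (inputs, labels) in enumerate(loader, start=start_batch):
        # ... training / feature-extraction step ...
        if i % SAVE_EVERY == 0:
            torch.save({'batch_idx': i, 'torch_rng': torch.get_rng_state()}, CKPT)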

You’ve mentioned some good ideas! Would you be interested in implementing these debugging utils?

I surely am interested. As a heavy PyTorch user, I’m quite familiar with the APIs. As for implementing, I may need some time to first read the source code of the dataloader and the dataset. I’m currently reading and working on it, but I’m quite a rookie at this. If any plans or hints exist for implementing this, please help me with it.

You could start by reading through the Contribution Guide and opening a feature request discussing your ideas. Once you are aligned with the code owners on the features, you could then start implementing them.

Thanks a lot! It’s my first time trying to contribute to PyTorch; this is really helpful to me. I’ll start working on this.


Hi, are there any updates on this topic? I’m facing the same need.

Hello, I found an implementation online that seems to solve our problem: https://github.com/facebookresearch/vissl/blob/main/vissl/data/data_helper.py
It looks like we can call set_start_iter on StatefulDistributedSampler to get the stateful resume behavior we want.
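For anyone landing here, a rough usage sketch going by the code in that file (it assumes a distributed setup is already initialized, and dataset / batch_size / start_iter are placeholders; check the file for the exact constructor signature):

    from torch.utils.data import DataLoader
    from vissl.data.data_helper import StatefulDistributedSampler

    sampler = StatefulDistributedSampler(dataset, batch_size=batch_size)
    sampler.set_start_iter(start_iter)  # e.g. the iteration loaded from a checkpoint
    loader = DataLoader(dataset, batch_size=batch_size, sampler=sampler)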