Resume iterating dataloader from checkpoint batch_idx

Hi,

I was wondering whether it is possible to resume iterating through a dataloader from a checkpoint.

For example:

    dataloaders_dict = {phase: torch.utils.data.DataLoader(datasets_dict[phase], batch_size=args.batch_size, num_workers=args.num_workers, shuffle=False) for phase in ['train']}  # make sure shuffling is False in case you restart

    if os.path.isdir(args.batchidx_checkpoint):
        checkpoint = torch.load(args.batchidx_checkpoint + 'batches.pt')
        batch_idx = checkpoint['batch_idx'] + 1
    else:
        os.makedirs(args.batchidx_checkpoint)
        batch_idx = 0

    for batch_idx, (inputs, labels) in enumerate(dataloaders_dict['train']):
        # print(batch_idx)
        # do some feature extraction and save feature arrays in an hdf5 file

I have seen some examples mentioning a random sampler, but was wondering if it is really necessary…
I also have about 5 million images to go through, so it is a fairly large dataset to iterate over. My end goal is to save features for each image using a pretrained network, so no training is actually required. I just want to resume feature extraction from a checkpointed batch id…
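For completeness, the saving side inside that loop is roughly something like this (the save frequency is just a placeholder):

    # inside the feature-extraction loop above, roughly:
    if batch_idx % 1000 == 0:  # placeholder save frequency
        torch.save({'batch_idx': batch_idx}, args.batchidx_checkpoint + 'batches.pt')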

Cheers,


Hi,

I don’t think you can do that with the vanilla modules.
But it should be fairly easy to write your own sampler inspired by SequentialSampler() (it’s only about 15 lines of code): add a new argument for a start point and return iter(range(self.start_idx, len(self.data_source))) in __iter__().


Hi,

Thanks for the reply. I created a new SequentialSampler class in my script:

class SequentialSampler2():
    r"""Samples elements sequentially, always in the same order,
    starting from a given index.

    Arguments:
        data_source (Dataset): dataset to sample from
        start_idx (int): index to start sampling from
    """

    def __init__(self, data_source, start_idx=0):
        self.data_source = data_source
        self.start_idx = start_idx

    def __iter__(self):
        return iter(range(self.start_idx, len(self.data_source)))

    def __len__(self):
        return len(self.data_source)

and tried to pass it as the sampler in my dataloader, like this:

datasets_dict = {phase: Patches(root= args.root_path + phase, phase = phase) for phase in ['train']}
sampler = {phase: SequentialSampler2(len(datasets_dict[phase])) for phase in ['train']}
dataloaders_dict = {phase: torch.utils.data.DataLoader(datasets_dict[phase], batch_size=args.batch_size, sampler=sampler, num_workers=args.num_workers, shuffle=False) for phase in ['train']}  # make sure shuffling is False in case you restart

resnet50 = models.resnet50(pretrained=True)
for param in resnet50.parameters():
    param.requires_grad = False


# check batch checkpoint 
if os.path.isdir(args.batchidx_checkpoint):
    checkpoint = torch.load(args.batchidx_checkpoint + 'batches.pt')
    batch_idx = checkpoint['batch_idx'] + 1
else:
    os.makedirs(args.batchidx_checkpoint)
    batch_idx = 0

# do some other stuff 

for batch_idx, (inputs40x, inputs20x, inputs5x, paths40x, paths20x, paths5x, labels) in enumerate(dataloaders_dict[phase]):
    inputs40x = inputs40x.to(device)
    inputs20x = inputs20x.to(device)
    inputs5x = inputs5x.to(device)

    labels = labels.to(device)

however, I appear to get an error:

ValueError: sampler should be an instance of torch.utils.data.Sampler, but got sampler={'train': <__main__.SequentialSampler2 object at 0x128b2d748>}

Perhaps I’ve misunderstood how to feed a sampler to the dataloader. I’m also not sure whether this is the right approach for restarting the for loop from a certain batch_idx.

It’s just that your sampler should be a subclass of the original torch.utils.data.Sampler, as the error states.
You can fix that by declaring class SequentialSampler2(torch.utils.data.Sampler): so it inherits from it, and by adding super().__init__() as the first line of your __init__ to make sure the parent class is initialized.
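A minimal sketch of what that could look like (note that the DataLoader also needs the per-phase sampler, sampler[phase], rather than the whole dict; deriving start_idx from the saved batch_idx is just one possible choice):

class SequentialSampler2(torch.utils.data.Sampler):
    r"""Sequential sampler that starts from ``start_idx``."""

    def __init__(self, data_source, start_idx=0):
        super().__init__(data_source)  # older Sampler.__init__ expects data_source
        self.data_source = data_source
        self.start_idx = start_idx

    def __iter__(self):
        return iter(range(self.start_idx, len(self.data_source)))

    def __len__(self):
        return len(self.data_source)

# resume roughly where the last run stopped
sampler = {phase: SequentialSampler2(datasets_dict[phase], start_idx=batch_idx * args.batch_size)
           for phase in ['train']}
dataloaders_dict = {phase: torch.utils.data.DataLoader(datasets_dict[phase],
                                                        batch_size=args.batch_size,
                                                        sampler=sampler[phase],
                                                        num_workers=args.num_workers)
                    for phase in ['train']}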


Ahh, ok, thank you! 🙂

Why doesn’t PyTorch support this out of the box? Doesn’t the lack of it make reproducible ML much harder?

Hi,

I’m sure this could be added, but there was no explicit feature request for it.
Could you clarify why this is particularly important for reproducible ML? Do you often want to start iterating your dataset from the middle?

This is crucial when saving checkpoints (our cluster only allows 1-day jobs, so I have to restart another day from exactly where I left off).

Reproducibility in deep learning is a big issue, and this doesn’t make it any better.

I’d make a feature request.


I see the point about resuming training.
But do you actually save all the random state when you checkpoint? Otherwise the sampler won’t continue sampling exactly the same sequence.
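For reference, capturing the relevant RNG states alongside the batch index would look roughly like this (restore them with the matching set_* calls when you resume):

import random
import numpy as np
import torch

checkpoint = {
    'batch_idx': batch_idx,
    'python_rng_state': random.getstate(),
    'numpy_rng_state': np.random.get_state(),
    'torch_rng_state': torch.get_rng_state(),
}
if torch.cuda.is_available():
    checkpoint['cuda_rng_states'] = torch.cuda.get_rng_state_all()
torch.save(checkpoint, 'batches.pt')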

Hi all, I recently ran into this problem as well; my sampler is a WeightedRandomSampler. I tried to add @albanD’s solution but couldn’t get it to work with my custom sampler.

class WeightedRandomSampler2(torch.utils.data.Sampler):
    r"""Samples elements from ``[0,..,len(weights)-1]`` with given probabilities (weights).
    Args:
        weights (sequence)   : a sequence of weights, not necessary summing up to one
        num_samples (int): number of samples to draw
        replacement (bool): if ``True``, samples are drawn with replacement.
            If not, they are drawn without replacement, which means that when a
            sample index is drawn for a row, it cannot be drawn again for that row.
    Example:
        >>> list(WeightedRandomSampler([0.1, 0.9, 0.4, 0.7, 3.0, 0.6], 5, replacement=True))
        [0, 0, 0, 1, 0]
        >>> list(WeightedRandomSampler([0.9, 0.4, 0.05, 0.2, 0.3, 0.1], 5, replacement=False))
        [0, 1, 4, 3, 2]
    """

    def __init__(self, weights, num_samples, replacement=True):

        super().__init__()

        if not isinstance(num_samples, _int_classes) or isinstance(num_samples, bool) or \
                num_samples <= 0:
            raise ValueError("num_samples should be a positive integer "
                             "value, but got num_samples={}".format(num_samples))
        if not isinstance(replacement, bool):
            raise ValueError("replacement should be a boolean value, but got "
                             "replacement={}".format(replacement))
        self.weights = torch.as_tensor(weights, dtype=torch.double)
        self.num_samples = num_samples
        self.replacement = replacement

    def __iter__(self):
        return iter(self.start_idx, torch.multinomial(self.weights, self.num_samples, self.replacement).tolist())

    def __len__(self):
        return self.num_samples

That is, after adding self.start_idx in __iter__(), I receive this error when running the script:

data_sampler = WeightedRandomSampler2(sample_weights, num_samples=len(sample_weights), replacement=True)
---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
<ipython-input-14-b420c064a884> in <module>()
----> 1 data_sampler = WeightedRandomSampler2(sample_weights, num_samples=len(sample_weights), replacement=True)

<ipython-input-10-3c0a1378cd2d> in __init__(self, weights, num_samples, replacement)
     20     def __init__(self, weights, num_samples, replacement=True):
     21 
---> 22       super().__init__()
     23 
     24       if not isinstance(num_samples, _int_classes) or isinstance(num_samples, bool) or               num_samples <= 0:

TypeError: __init__() missing 1 required positional argument: 'data_source'

Any idea?