Resume training from a specific batch in the dataloader

Hello,
Is there a way to resume training from a specific batch in the dataloader?
What I mean is: instead of looping over the dataloader from the first batch, start the loop from a given batch k. Something like this would work if the dataloader were a simple list:

from torch.utils.data import DataLoader
loader = DataLoader([...])
for batch in loader[k:]:
    # training loop
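
The closest runnable approximation of this slicing idea that I can think of is itertools.islice over the loader, though I assume it would still produce and discard the first k batches under the hood:

import itertools
from torch.utils.data import DataLoader

loader = DataLoader([...])
k = 100  # example batch index to resume from
for batch in itertools.islice(loader, k, None):
    # training loop
    ...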

My use case for this is that I want to resume training, with the same RNG seed for reproducibility, but with one hyperparameter changed.

Thanks for your help!

Hi Adam, did you find a way to do this?

Not really…
I tried fast-forwarding an iterator over the dataloader to a specific batch, but this takes quite some time on large datasets since you still have to load (and discard) all the batches before it.

iterator = iter(loader)
for i in range(len(loader)):
    data = next(iterator)
    if i < k:  # k is the batch index where training resumes
        continue  # skip (but still load) the first k batches
    # training loop
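
A more scalable workaround might be to skip at the sampler level instead, so the first k batches are never loaded at all. This is only a sketch, assuming a map-style dataset and shuffling driven by a fixed seed; the class SkipBatchesSampler and its arguments are made up for illustration:

import torch
from torch.utils.data import DataLoader, Sampler

class SkipBatchesSampler(Sampler):
    # Rebuilds the same shuffled index order as a seeded run, then drops the
    # first k * batch_size indices, so the skipped batches are never loaded.
    def __init__(self, dataset, batch_size, k, seed=0):
        self.dataset = dataset
        self.batch_size = batch_size
        self.k = k
        self.seed = seed

    def __iter__(self):
        g = torch.Generator()
        g.manual_seed(self.seed)  # same seed as the original run -> same order
        indices = torch.randperm(len(self.dataset), generator=g).tolist()
        return iter(indices[self.k * self.batch_size:])

    def __len__(self):
        return len(self.dataset) - self.k * self.batch_size

# hypothetical usage:
# sampler = SkipBatchesSampler(dataset, batch_size=32, k=100, seed=0)
# loader = DataLoader(dataset, batch_size=32, sampler=sampler)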

No worries, thank you