Training for more iterations

Hi all,

I am totally new to PyTorch, as I am transitioning from Keras.

I am about to translate one of my projects, in which one epoch lasts either the number of samples divided by the batch size or a number of iterations defined by the user.

Now I am wondering how I could accomplish the same thing in PyTorch.

I first create a Dataset class and wrap PyTorch’s DataLoader around it, as recommended in this tutorial.

That works fine when the number of iterations is len(samples) / batch_size.
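For reference, a minimal sketch of where that number comes from, assuming a hypothetical toy TensorDataset of 10 samples and batch_size=4: len() of a DataLoader is the number of batches in one pass, i.e. ceil(N / batch_size) by default, or N // batch_size with drop_last=True.

```python
import math

import torch
from torch.utils.data import DataLoader, TensorDataset

# Hypothetical toy data: 10 samples with 3 features each.
dataset = TensorDataset(torch.randn(10, 3))
batch_size = 4

# By default the last, smaller batch is kept: ceil(10 / 4) = 3 batches per pass.
loader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
print(len(loader), math.ceil(len(dataset) / batch_size))  # 3 3

# With drop_last=True the incomplete batch is discarded: 10 // 4 = 2 batches.
loader_drop = DataLoader(dataset, batch_size=batch_size, shuffle=True, drop_last=True)
print(len(loader_drop), len(dataset) // batch_size)  # 2 2
```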

Now I want to increase the number of iterations, so I guess the generator has to be restarted and reshuffled.

Any help on this would be awesome,



The easiest way would be to just iterate over subepochs and break after the number of iterations you want.

for e in range(num_epochs):
    it = 0
    while it < num_it:
        for batch in data_loader:
            # ... training step on `batch` goes here ...
            it += 1
            if it >= num_it:
                break

It’s a tad verbose, but at about five additional lines of code, you might pretend it’s an enlightenment of the “explicit is better than implicit” variety.
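If the counter bookkeeping gets in the way, the same idea (re-iterate the DataLoader and stop after num_it batches) can be packaged as a small generator. This is only a sketch: iterate_n is a hypothetical helper name, and the toy dataset, batch size and num_it are made up for illustration.

```python
import itertools

import torch
from torch.utils.data import DataLoader, TensorDataset


def iterate_n(data_loader, num_it):
    """Hypothetical helper: yield exactly `num_it` batches, restarting
    `data_loader` whenever it runs out (assumes the loader is non-empty)."""
    def endless():
        while True:
            # Each pass calls iter(data_loader) again, so with shuffle=True
            # every pass gets a fresh permutation.
            for batch in data_loader:
                yield batch

    return itertools.islice(endless(), num_it)


# Hypothetical toy setup: 10 samples, batch_size=4 -> 3 batches per pass.
dataset = TensorDataset(torch.arange(10))
data_loader = DataLoader(dataset, batch_size=4, shuffle=True)

num_epochs, num_it = 2, 7  # 7 iterations per epoch, more than one pass
counts = []
for e in range(num_epochs):
    n = 0
    for (x,) in iterate_n(data_loader, num_it):
        # ... training step on `x` would go here ...
        n += 1
    counts.append(n)
print(counts)  # [7, 7]
```

Note that itertools.cycle would not be a drop-in replacement here: it saves the batches from the first pass and replays them, so the data would not be reshuffled.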

As you note, the proper abstraction to hook into is the DataLoader. A perhaps more sophisticated technique is therefore to write your own variant of RandomSampler and pass it to the DataLoader. This should be simple enough, too: in __iter__ you just need to return an iterator over several distinct random permutations (randperms) chained together, similar to what RandomSampler does for a single one.
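A minimal sketch of such a sampler, assuming you want each DataLoader pass to cover the data num_perms times. MultiPermutationSampler and num_perms are hypothetical names, not part of PyTorch; only Sampler, DataLoader and torch.randperm are real APIs.

```python
import torch
from torch.utils.data import DataLoader, Sampler, TensorDataset


class MultiPermutationSampler(Sampler):
    """Hypothetical sampler: yields the indices of `num_perms` independent
    random permutations of the dataset, one after another."""

    def __init__(self, data_source, num_perms):
        self.data_source = data_source
        self.num_perms = num_perms

    def __iter__(self):
        n = len(self.data_source)
        # New permutations are drawn every time iteration starts.
        perms = [torch.randperm(n) for _ in range(self.num_perms)]
        return iter(torch.cat(perms).tolist())

    def __len__(self):
        # Lets len(DataLoader) report the longer "epoch".
        return len(self.data_source) * self.num_perms


# Hypothetical usage: 10 samples seen 3 times per pass -> ceil(30 / 4) = 8 batches.
dataset = TensorDataset(torch.arange(10))
sampler = MultiPermutationSampler(dataset, num_perms=3)
loader = DataLoader(dataset, batch_size=4, sampler=sampler)
print(len(loader))  # 8
```

Since the sampler already does the shuffling, don't also pass shuffle=True; DataLoader rejects sampler and shuffle together. With this design a batch can straddle the boundary between two permutations.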

Best regards


Hi Thomas, and thanks for the quick reply!

What happens if the data_loader runs out of batches?

Does it ‘reset’ itself and shuffle the data?



No. A for loop over the DataLoader calls its __iter__ to get an iterator, and once that iterator is exhausted it raises StopIteration (the usual Python mechanism), which ends the for loop. That’s why the manual solution adds an outer loop that runs this for loop several times if needed.
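To make that concrete, here is a small sketch that drives the iterator by hand, assuming a hypothetical toy dataset of 6 samples and batch_size=4: after two batches the iterator raises StopIteration, and a new iter(loader) starts a fresh pass.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

# Hypothetical toy data: 6 samples, batch_size=4 -> 2 batches per pass.
loader = DataLoader(TensorDataset(torch.arange(6)), batch_size=4, shuffle=True)

it = iter(loader)        # this is what a `for` loop does behind the scenes
first = next(it)[0]      # a batch of 4 samples
second = next(it)[0]     # the last, smaller batch of 2 samples
try:
    next(it)
    exhausted = False
except StopIteration:    # this is what ends `for batch in loader`
    exhausted = True
print(len(first), len(second), exhausted)  # 4 2 True

# Calling iter(loader) again (i.e. starting a new `for` loop) begins a fresh pass.
restarted = next(iter(loader))[0]
print(len(restarted))  # 4
```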

Yes, that makes sense.

Do you know whether the dataset is reshuffled each time ‘for batch in data_loader’ is repeated?

Thanks again,


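Yes, it is, as long as the DataLoader was created with shuffle=True (which makes it use a RandomSampler). The linked line of code in RandomSampler generates a random permutation and returns an iterator over it, and this happens at the start of every for loop.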
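A quick way to check this yourself, sketched with a hypothetical toy dataset of 8 samples loaded as a single batch: each new for loop yields a complete permutation of the indices, and the order generally changes from pass to pass.

```python
import torch
from torch.utils.data import DataLoader, RandomSampler, TensorDataset

torch.manual_seed(0)  # only to make the run reproducible
dataset = TensorDataset(torch.arange(8))
# shuffle=True makes the DataLoader build a RandomSampler internally.
loader = DataLoader(dataset, batch_size=8, shuffle=True)
print(isinstance(loader.sampler, RandomSampler))  # True

orders = []
for epoch in range(3):
    for (x,) in loader:  # each new loop calls iter(loader) -> a fresh permutation
        orders.append(x.tolist())

for order in orders:
    print(order)  # each line is a permutation of 0..7
```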

Awesome, thanks Thomas!