Do I have to manually loop over batches?


So I have a Dataset class for my data, and I use a DataLoader with it. I then have, e.g., a train() function that does the training loop. In that training loop I loop over the batches manually.

Is that really the way to go with PyTorch?

As you can see in this little video, the purpose of the DataLoader is to hand you the batches of data directly.

If your Dataset class implements the method __getitem__(self, index: int), which returns one instance of your data (a map-style dataset also implements __len__), the DataLoader is supposed to use it (efficiently) to give you the data batches directly in your loop.
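A minimal sketch of that division of labour, assuming a hypothetical toy dataset (ToyDataset, with made-up features and labels): __getitem__ returns a single sample, and the DataLoader collates consecutive samples into batches for you.

```python
import torch
from torch.utils.data import Dataset, DataLoader

class ToyDataset(Dataset):
    """Hypothetical map-style dataset: sample i is (features, label)."""

    def __init__(self, n_samples: int = 10, n_features: int = 3):
        self.x = torch.arange(n_samples * n_features, dtype=torch.float32).reshape(
            n_samples, n_features
        )
        self.y = torch.arange(n_samples) % 2  # made-up binary labels

    def __len__(self) -> int:
        # The DataLoader's default samplers need to know the dataset length
        return len(self.x)

    def __getitem__(self, index: int):
        # Return ONE sample; the DataLoader stacks samples into a batch
        return self.x[index], self.y[index]

loader = DataLoader(ToyDataset(), batch_size=4, shuffle=False)

batch_shapes = []
for xb, yb in loader:  # each iteration yields a ready-made batch
    batch_shapes.append((tuple(xb.shape), tuple(yb.shape)))
print(batch_shapes)
```

With 10 samples and batch_size=4 you get two full batches and a final batch of 2; the loop body never indexes individual samples.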

Hi Fancy!

Yes. PyTorch does not offer a general-purpose training-loop function.

I think the issue is that for the most basic use case the training loop
is quite simple and hardly a bother to write by hand. But there are
lots of different things you might want to do in a training loop as you
move beyond the most basic use case, so designing a general-purpose
training-loop function becomes messy.
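To give a sense of how short the basic case is, here is a minimal sketch of a hand-written loop over epochs and batches. It assumes a hypothetical toy problem (fitting y = 2x + 1 with a single nn.Linear layer), and the train() helper, learning rate and epoch count are illustrative choices, not anything PyTorch prescribes.

```python
import torch
from torch import nn
from torch.utils.data import TensorDataset, DataLoader

torch.manual_seed(0)

# Hypothetical toy regression data: y = 2x + 1 plus a little noise
x = torch.linspace(-1, 1, 64).unsqueeze(1)
y = 2 * x + 1 + 0.05 * torch.randn_like(x)
loader = DataLoader(TensorDataset(x, y), batch_size=16, shuffle=True)

model = nn.Linear(1, 1)
loss_fn = nn.MSELoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

def train(model, loader, loss_fn, optimizer, n_epochs: int) -> float:
    """Minimal hand-written loop: epochs outside, batches inside.

    Returns the mean training loss of the last epoch.
    """
    model.train()
    epoch_loss = 0.0
    for epoch in range(n_epochs):
        epoch_loss = 0.0
        for xb, yb in loader:        # the DataLoader supplies the batches
            optimizer.zero_grad()    # clear gradients from the previous step
            loss = loss_fn(model(xb), yb)
            loss.backward()          # backpropagate
            optimizer.step()         # update the parameters
            epoch_loss += loss.item() * len(xb)
        epoch_loss /= len(loader.dataset)
    return epoch_loss

final_loss = train(model, loader, loss_fn, optimizer, n_epochs=50)
print(f"final training loss: {final_loss:.4f}")
```

Everything beyond these few lines (metrics, validation, logging) is where the design choices start to multiply.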

The training loop typically loops over epochs as well as batches.
Do you want to calculate one or more accuracy measures? Do
you want to do so periodically after every n batches or n epochs?
Do you want to periodically calculate losses or accuracies for a
validation set? Do you want to implement early stopping based on
some validation measure? Are you collecting statistics or printing out
progress messages as you train? Are you training multiple networks
with the same batches or training an adversarial network?
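As a rough illustration of how quickly those options add up, the sketch below adds just three of them to a basic loop: validation every n epochs, progress messages, and patience-based early stopping that restores the best weights. The data, the evaluate() helper and the settings max_epochs, val_every and patience are all hypothetical choices made for this example.

```python
import copy
import torch
from torch import nn
from torch.utils.data import TensorDataset, DataLoader, random_split

torch.manual_seed(0)

# Hypothetical toy data, split into training and validation sets
x = torch.linspace(-1, 1, 100).unsqueeze(1)
y = 2 * x + 1 + 0.05 * torch.randn_like(x)
train_ds, val_ds = random_split(TensorDataset(x, y), [80, 20])
train_loader = DataLoader(train_ds, batch_size=16, shuffle=True)
val_loader = DataLoader(val_ds, batch_size=20)

model = nn.Linear(1, 1)
loss_fn = nn.MSELoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

@torch.no_grad()
def evaluate(model, loader, loss_fn) -> float:
    """Mean loss over a whole loader, without tracking gradients."""
    model.eval()
    total = 0.0
    for xb, yb in loader:
        total += loss_fn(model(xb), yb).item() * len(xb)
    model.train()
    return total / len(loader.dataset)

max_epochs, val_every, patience = 200, 5, 3  # hypothetical settings
best_val, best_state, bad_checks = float("inf"), None, 0
history = []  # (epoch, val_loss) pairs kept for later inspection

for epoch in range(1, max_epochs + 1):
    for xb, yb in train_loader:
        optimizer.zero_grad()
        loss_fn(model(xb), yb).backward()
        optimizer.step()

    if epoch % val_every == 0:            # validate every n epochs
        val_loss = evaluate(model, val_loader, loss_fn)
        history.append((epoch, val_loss))
        print(f"epoch {epoch:3d}  val loss {val_loss:.4f}")  # progress message
        if val_loss < best_val - 1e-4:    # counts as a real improvement
            best_val, bad_checks = val_loss, 0
            best_state = copy.deepcopy(model.state_dict())
        else:
            bad_checks += 1
            if bad_checks >= patience:    # early stopping
                break

model.load_state_dict(best_state)  # restore the best-validating weights
```

Each of these choices (how often to validate, what counts as an improvement, whether to restore weights) is exactly the kind of knob a general-purpose function would have to expose.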

These, and more, are some of the features a pre-packaged
training-loop function might be asked to support. Which would
you implement? What would the interface to the function look like?

Good luck.

K. Frank