Training for more iterations

Hi all,

I am totally new to PyTorch, as I am transitioning from Keras.

I am about to translate one of my projects, where one epoch is either as long as the number of samples divided by the batch size, or as long as a number of iterations defined by the user.

Now I am wondering how I could accomplish the same thing in PyTorch.

I first create a Dataset class and wrap PyTorch's DataLoader around it, as recommended in this tutorial.

That works fine for number of iterations = len(samples) / batch_size.

Now I want to increase the number of iterations, so I guess the generator has to be restarted and the data shuffled again.

Any help on this would be awesome,



The easiest way would be to just iterate over subepochs and break after the number of iterations you want.

for e in range(num_epochs):
  it = 0
  while it < num_it:
    for batch in data_loader:
      it += 1
      # ... training step ...
      if it >= num_it:
        break

It’s a tad verbose, but at five additional lines of code, you might pretend it’s enlightenment of the “explicit is better than implicit” variety.

As you note, the proper abstraction to tie into would be the DataLoader. The perhaps more sophisticated technique is therefore to write your own variation of RandomSampler to pass to the DataLoader. This should be simple enough, too: you just need to return, in __iter__, an iterator over a few distinct randperms, similar to what RandomSampler does with a single one.
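A minimal sketch of that idea (the name MultiPassRandomSampler and its num_passes parameter are my invention for illustration, not a torch API): the sampler chains several fresh random permutations, so a single DataLoader "epoch" yields num_passes * len(dataset) samples.

```python
import torch
from torch.utils.data import DataLoader, Dataset, Sampler

class MultiPassRandomSampler(Sampler):
    def __init__(self, data_source, num_passes):
        self.data_source = data_source
        self.num_passes = num_passes

    def __iter__(self):
        for _ in range(self.num_passes):
            # a fresh shuffle for every pass, like RandomSampler does once
            yield from torch.randperm(len(self.data_source)).tolist()

    def __len__(self):
        return self.num_passes * len(self.data_source)

class ToyDataset(Dataset):
    def __len__(self):
        return 10
    def __getitem__(self, idx):
        return idx

ds = ToyDataset()
loader = DataLoader(ds, batch_size=4,
                    sampler=MultiPassRandomSampler(ds, num_passes=3))
print(len(list(loader)))  # 30 indices in batches of 4 -> 8 batches
```

Since the DataLoader just consumes whatever indices the sampler yields, the usual single epoch loop then runs for as many iterations as the sampler provides.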

Best regards


Hi Thomas and thanks for the quick reply!

What happens if the data_loader runs out of batches?

Does it ‘reset’ itself and shuffle the data?



No, the DataLoader gives an iterator (__iter__), and that iterator terminates the for loop via the usual Python mechanism (StopIteration). That’s why the manual solution puts an outer loop around the for loop, to run it several times if needed.
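A plain-Python illustration of that mechanism (no torch needed): a for loop calls iter() once and next() repeatedly until StopIteration, then stops.

```python
data = ["batch0", "batch1", "batch2"]
it = iter(data)        # what a for loop does under the hood
print(next(it))        # batch0
print(next(it))        # batch1
print(next(it))        # batch2
try:
    next(it)           # exhausted iterators raise StopIteration
except StopIteration:
    print("loop over; a fresh iter() is needed for another pass")
```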

Yes, that makes sense.

Do you have any idea if the dataset is shuffled when ‘for batch in data loader’ is repeated?

Thanks again,


Yes, it is. The linked line of code in RandomSampler generates a random permutation and returns an iterator over it. This happens at every start of a for loop over the DataLoader.
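A quick way to check this yourself (assuming torch is installed): every `for batch in loader` pass covers all samples exactly once, generally in a new order each time.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

ds = TensorDataset(torch.arange(8))
loader = DataLoader(ds, batch_size=4, shuffle=True)

for epoch in range(2):
    seen = [int(x) for (batch,) in loader for x in batch]
    print(epoch, seen)
    # each pass is a full permutation of the dataset indices
    assert sorted(seen) == list(range(8))
```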

Awesome, thanks Thomas!