Hello! I saw the following code today in an LSTM/MNIST example:
train_loader = Data.DataLoader(dataset=train_data, batch_size=BATCH_SIZE, shuffle=True)
for epoch in range(EPOCH):
    for step, (x, y) in enumerate(train_loader):  # gives batch data
For this code, how does the second for loop work (a for with a tuple)? What is the structure of the DataLoader object here? Does the for loop also normalize the MNIST data, as the tutorial suggests? (reference: https://morvanzhou.github.io/tutorials/machine-learning/torch/4-02-RNN-classification/) You may just provide a link here, thanks.
The normalization is usually done in the dataset via the transform argument, not in the for loop itself.
The dataloader provides a Python iterator returning tuples, and enumerate adds the step. You can see this manually (in Python 3):
it = iter(train_loader)
first = next(it)
second = next(it)
will give you the first two items from the train_loader that the for loop would get.
Python iterators are a concept many people ask and write about in various forums. I don't know of a canonical reference to link to, but searching for "python iterators" will turn up plenty of material.
step, (x, y) works due to "tuple unpacking", again a general Python feature.
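To see how enumerate and tuple unpacking combine without needing a real DataLoader, here is a minimal sketch using a plain list of (data, label) pairs as a stand-in for the loader:

```python
# Each element yielded by the loader is a (data, label) tuple.
# enumerate wraps it as (step, (data, label)), and tuple unpacking
# in the for statement splits that into three names at once.
batches = [("x0", "y0"), ("x1", "y1"), ("x2", "y2")]

for step, (x, y) in enumerate(batches):
    print(step, x, y)
```

This prints `0 x0 y0`, `1 x1 y1`, `2 x2 y2` — exactly the pattern the training loop uses, just with tensors instead of strings.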
The snippet basically says that, for every epoch, the train_loader is iterated and returns x and y, i.e. an input batch and its corresponding labels. The second for loop iterates over the entire dataset, and enumerate simply assigns the index of the i-th batch to the variable step. The normalization happens when the examples are loaded: the dataset's transform argument applies a chain of preprocessing steps such as resizing and scaling to [0, 1].
In the CIFAR-10 tutorial at https://pytorch.org/tutorials/beginner/blitz/cifar10_tutorial.html, the following approach is used to obtain the data and labels:
trainloader = torch.utils.data.DataLoader(trainset, batch_size=4,
                                          shuffle=True, num_workers=2)
def imshow(img):
    img = img / 2 + 0.5     # unnormalize
    npimg = img.numpy()
    plt.imshow(np.transpose(npimg, (1, 2, 0)))
# get some random training images
dataiter = iter(trainloader)
images, labels = next(dataiter)
If we keep calling next(dataiter) repeatedly, will it give an error when it reaches the end of the dataset, or will the iterator reset to the start? Also, how does shuffle work in this approach?
iter() behaves like any other iterator in Python: it raises a StopIteration exception when the end is reached, rather than resetting. With shuffle=True, a new random order is drawn each time a fresh iterator is created, i.e. once per epoch. In the PyTorch tutorial, after loading the data, iter() followed by next() is used just to grab a few images and display them in the notebook. In the training loop, a for loop is used to iterate over the training data.
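A minimal sketch of the exhaust-and-restart behavior, using a tiny TensorDataset instead of CIFAR-10 so it runs instantly:

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

# 6 samples, batch_size 2 -> exactly 3 batches per pass.
data = TensorDataset(torch.arange(6).float(), torch.arange(6))
loader = DataLoader(data, batch_size=2, shuffle=True)

it = iter(loader)
batches = [next(it) for _ in range(3)]  # consumes the whole dataset

try:
    next(it)  # a fourth call runs past the end
except StopIteration:
    print("exhausted; create a new iterator to start over")

# A fresh iterator (which is what `for ... in loader` does each epoch)
# restarts from the beginning, and with shuffle=True draws a new
# random order for the new pass.
it = iter(loader)
first_again = next(it)
```

So the for loop never sees StopIteration itself — the for statement catches it internally and ends the epoch, and the next epoch's `for` creates a fresh, reshuffled iterator.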