Hello! I saw the following code today in an LSTM/MNIST example:
```python
import torch.utils.data as Data

train_loader = Data.DataLoader(dataset=train_data, batch_size=BATCH_SIZE, shuffle=True)
for epoch in range(EPOCH):
    for step, (x, y) in enumerate(train_loader):  # gives batch data
        ...  # loop body elided in the original example
```
The normalization is usually done in the dataset via the transform argument.
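To show where that transform hook sits, here is a minimal pure-Python sketch. `ToyDataset`, `compose`, `scale_to_unit` and `normalize` are hypothetical stand-ins for a torch `Dataset` and torchvision-style transforms, not PyTorch API; the point is only that the dataset, not the loader, applies the transform each time a sample is fetched.

```python
from typing import Callable, List, Optional, Tuple


def compose(*fns: Callable) -> Callable:
    """Chain preprocessing steps, similar in spirit to torchvision's Compose."""
    def apply(x):
        for fn in fns:
            x = fn(x)
        return x
    return apply


def scale_to_unit(pixels: List[int]) -> List[float]:
    """Hypothetical step: map 0..255 pixel values into [0, 1]."""
    return [p / 255 for p in pixels]


def normalize(mean: float, std: float) -> Callable:
    """Hypothetical step: standardize values with a fixed mean and std."""
    def apply(values: List[float]) -> List[float]:
        return [(v - mean) / std for v in values]
    return apply


class ToyDataset:
    """Stand-in for a map-style dataset: __len__ plus __getitem__."""

    def __init__(self, samples, labels, transform: Optional[Callable] = None):
        self.samples = samples
        self.labels = labels
        self.transform = transform

    def __len__(self) -> int:
        return len(self.samples)

    def __getitem__(self, idx: int) -> Tuple[object, int]:
        x = self.samples[idx]
        if self.transform is not None:
            x = self.transform(x)  # preprocessing happens here, per sample
        return x, self.labels[idx]


ds = ToyDataset([[0, 255], [51, 204]], [0, 1],
                transform=compose(scale_to_unit, normalize(0.5, 0.5)))
print(ds[0])  # ([-1.0, 1.0], 0)
```

With mean 0.5 and std 0.5, values in [0, 1] end up in [-1, 1]; computing `x / 2 + 0.5` maps them back to [0, 1].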
The DataLoader provides a Python iterator that yields the batches (here, input/label pairs), and enumerate adds the step number. You can try this manually (in Python 3):
```python
it = iter(train_loader)
first = next(it)
second = next(it)
```
This will give you the first two items from train_loader, i.e. what the first two iterations of the for loop would receive.
Python iterators are a concept many people ask and write about in various forums. I don’t know a canonical reference to link to, but if you search for “python iterators” you’ll find plenty on them.
Finally, `step, (x, y)` works because of “tuple unpacking”, again a general Python feature: enumerate yields (step, batch) pairs, and each batch is in turn unpacked into x and y.
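Putting the iterator protocol and tuple unpacking together, the nested loop from the question behaves roughly like the hand-written version below. This is a simplified pure-Python sketch: `batches` is a plain list of hypothetical `(x, y)` pairs standing in for `train_loader`, and real DataLoader iteration involves more machinery (sampling, collation, worker processes).

```python
batches = [([1, 2], [0, 1]), ([3, 4], [1, 0])]  # hypothetical (x, y) batches

# The idiomatic loop:
seen_loop = []
for step, (x, y) in enumerate(batches):
    seen_loop.append((step, x, y))

# Roughly what Python does behind the scenes:
seen_manual = []
it = iter(enumerate(batches))      # enumerate wraps the underlying iterator
while True:
    try:
        item = next(it)            # item is a (step, batch) pair
    except StopIteration:          # raised once the underlying iterator is exhausted
        break
    step, (x, y) = item            # nested tuple unpacking
    seen_manual.append((step, x, y))

assert seen_loop == seen_manual
print(seen_manual[1])  # (1, [3, 4], [1, 0])
```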
The snippet basically says that, for every epoch, train_loader is iterated over and returns x and y, i.e. a batch of inputs and their corresponding labels. The inner for loop goes over the entire dataset once per epoch, one batch at a time, and enumerate simply assigns the index i to the variable step, so step refers to the i-th batch that is loaded (not the i-th individual training example). As train_loader loads the training examples, they are preprocessed by the dataset's transform argument (not by the DataLoader itself), which allows combining several preprocessing steps such as resizing, scaling to [0, 1], normalizing, etc.
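To make the batch-versus-example distinction concrete, here is a small sketch of batching with a hypothetical helper `iterate_batches` (not a PyTorch function). It mimics the DataLoader's default of keeping a smaller final batch (`drop_last=False`) but does not shuffle, so 10 samples with a batch size of 4 give 3 steps, not 10.

```python
from typing import Iterator, List, Sequence, Tuple


def iterate_batches(xs: Sequence, ys: Sequence,
                    batch_size: int) -> Iterator[Tuple[List, List]]:
    """Hypothetical stand-in for a non-shuffling DataLoader: yields (x_batch, y_batch)."""
    for start in range(0, len(xs), batch_size):
        yield list(xs[start:start + batch_size]), list(ys[start:start + batch_size])


samples = list(range(10))              # 10 toy "training examples"
labels = [s % 2 for s in samples]

steps = []
for step, (x, y) in enumerate(iterate_batches(samples, labels, batch_size=4)):
    steps.append(step)
    print(step, x, y)
# 0 [0, 1, 2, 3] [0, 1, 0, 1]
# 1 [4, 5, 6, 7] [0, 1, 0, 1]
# 2 [8, 9] [0, 1]
```

Only with `batch_size=1` would step coincide with the index of a single training example.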
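```python
import matplotlib.pyplot as plt
import numpy as np
import torch

# trainset: the training Dataset created earlier in the tutorial
trainloader = torch.utils.data.DataLoader(trainset, batch_size=4,
                                          shuffle=True, num_workers=2)

def imshow(img):
    img = img / 2 + 0.5     # unnormalize
    npimg = img.numpy()
    plt.imshow(np.transpose(npimg, (1, 2, 0)))  # CHW -> HWC for matplotlib
    plt.show()

# get some random training images
dataiter = iter(trainloader)
images, labels = next(dataiter)  # older PyTorch releases also accepted dataiter.next()
```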
If we keep calling dataiter.next() repeatedly, will it raise an error once it reaches the end of the dataset, or will the iterator reset to the start? Also, how does shuffle work in this approach?
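The iterator returned by iter() on a DataLoader behaves like any other Python iterator: it raises a StopIteration exception when the end is reached and does not reset by itself. To start another pass, call iter() on the DataLoader again; with shuffle=True, each new iter() call (which a for loop makes at the start of every epoch) gets a freshly shuffled order.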
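The sketch below mimics both points from the question (reaching the end of the data, and shuffling) with a hypothetical `ShufflingLoader` class; it is an illustration, not PyTorch's implementation. The loader is an iterable whose `__iter__` draws a new random order on every call, while each individual iterator is exhausted after one pass and then keeps raising `StopIteration`.

```python
import random
from typing import Iterator, List, Sequence


class ShufflingLoader:
    """Hypothetical loader: an iterable that reshuffles on every iter() call."""

    def __init__(self, data: Sequence, batch_size: int, seed: int = 0):
        self.data = list(data)
        self.batch_size = batch_size
        self.rng = random.Random(seed)

    def __iter__(self) -> Iterator[List]:
        order = list(range(len(self.data)))
        self.rng.shuffle(order)  # new permutation per pass, i.e. per epoch
        for start in range(0, len(order), self.batch_size):
            yield [self.data[i] for i in order[start:start + self.batch_size]]


loader = ShufflingLoader(range(6), batch_size=2)

dataiter = iter(loader)
print(next(dataiter), next(dataiter), next(dataiter))  # one full pass: 3 batches
try:
    next(dataiter)
except StopIteration:
    print("exhausted: StopIteration, no automatic reset")

dataiter = iter(loader)  # start a new pass explicitly; the order is reshuffled
print(next(dataiter))
```

A `for` loop over `loader` calls `iter()` for you each time it starts, which is why the training loop gets a fresh shuffle every epoch without any manual reset.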
In the PyTorch tutorial, after loading the data, iter() followed by next() is used just to grab a few images and display them in the notebook. In the training loop, a for loop is used to iterate over the training data.