Iterating through a DataLoader object

Hello! I saw the following code today in an LSTM/MNIST example:

train_loader = Data.DataLoader(dataset=train_data, batch_size=BATCH_SIZE, shuffle=True)

for epoch in range(EPOCH):
    for step, (x, y) in enumerate(train_loader):   # gives batch data
        ...

Given this code, how does the second for loop work (a for loop over tuples)? What is the structure of the DataLoader object here? Does the for loop also normalize the MNIST data, as the tutorial suggests? (reference: https://morvanzhou.github.io/tutorials/machine-learning/torch/4-02-RNN-classification/) A link to an explanation would be fine, thanks.

The normalization is usually done in the dataset via the transform argument.
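
For illustration, here is a minimal sketch of where that transform would live; the mean/std values below are the commonly quoted MNIST statistics, an assumption on my part rather than something from the linked tutorial:

import torchvision.datasets as dsets
import torchvision.transforms as transforms

train_data = dsets.MNIST(
    root='./mnist',
    train=True,
    transform=transforms.Compose([
        transforms.ToTensor(),                       # PIL image -> FloatTensor in [0, 1]
        transforms.Normalize((0.1307,), (0.3081,)),  # (x - mean) / std, assumed MNIST stats
    ]),
    download=True,
)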

The DataLoader provides a Python iterator returning tuples, and the enumerate adds the step index. You can try this out manually (in Python 3):

it = iter(train_loader)   # manually create the iterator the for loop would use
first = next(it)          # first (x, y) batch
second = next(it)         # second (x, y) batch

will give you the first two batches from the train_loader, exactly what the for loop would get.
Python iterators are a concept many people ask and write about in various forums. I don't know of a canonical reference to link to, but searching for "python iterators" will turn up plenty of material.

Finally, step, (x, y) works due to "tuple unpacking", again a general Python feature.
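
Both pieces can be seen in a tiny, self-contained example (plain Python, no PyTorch needed):

batches = [("x0", "y0"), ("x1", "y1"), ("x2", "y2")]  # stand-ins for (data, label) batches

for step, (x, y) in enumerate(batches):
    # enumerate yields (0, ("x0", "y0")), (1, ("x1", "y1")), ...
    # step, (x, y) unpacks the outer pair and the inner pair in one go
    print(step, x, y)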

Best regards

Thomas

The snippet says that, for every epoch, the train_loader is iterated over and yields x and y, i.e. a batch of inputs and their corresponding labels. The inner for loop runs over the entire dataset, and enumerate simply assigns the index of the i-th batch to the variable step. When the batches are loaded, the underlying dataset (not the DataLoader itself) normalizes them via its transform argument, which allows a combination of preprocessing steps such as resizing, scaling to [0, 1], and normalizing.
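
As a quick check of what one iteration yields, you can inspect the shapes of a single batch; the sizes below assume BATCH_SIZE = 64 and the MNIST setup sketched earlier:

x, y = next(iter(train_loader))
print(x.shape)  # torch.Size([64, 1, 28, 28]) -- batch of normalized images
print(y.shape)  # torch.Size([64])            -- batch of integer labels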

In the CIFAR-10 tutorial at https://pytorch.org/tutorials/beginner/blitz/cifar10_tutorial.html, the following approach is used to obtain the data and labels:

import torch
import numpy as np
import matplotlib.pyplot as plt

# trainset is a torchvision.datasets.CIFAR10 instance, defined earlier in the tutorial
trainloader = torch.utils.data.DataLoader(trainset, batch_size=4,
                                          shuffle=True, num_workers=2)

def imshow(img):
    img = img / 2 + 0.5     # unnormalize
    npimg = img.numpy()
    plt.imshow(np.transpose(npimg, (1, 2, 0)))


# get some random training images
dataiter = iter(trainloader)

images, labels = next(dataiter)   # written as dataiter.next() in older versions
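
For completeness, the tutorial then displays this batch by tiling the four images into a grid:

import torchvision
imshow(torchvision.utils.make_grid(images))   # tile the batch into one image for plotting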

If we keep calling next(dataiter) repeatedly, will it give an error when it reaches the end of the dataset, or will the iterator reset to the start? Also, how does shuffle work in this approach?

A DataLoader iterator behaves like any other iterator in Python: it raises a StopIteration exception when the end is reached.

In the PyTorch tutorial, after loading the data, iter() followed by next() is used just to get some images and display them in the notebook. In the training loop, a for loop is used to loop over the training data; it stops cleanly at the end of each epoch, and with shuffle=True a fresh random order is drawn every time a new iterator is created, as the sketch below shows.
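
A small self-contained sketch of both behaviors (the toy dataset here is made up for illustration, not taken from the tutorial):

import torch
from torch.utils.data import DataLoader, TensorDataset

dataset = TensorDataset(torch.arange(6))   # toy dataset holding the tensors 0..5
loader = DataLoader(dataset, batch_size=2, shuffle=True)

it = iter(loader)
for _ in range(3):                         # 6 samples / batch_size 2 = 3 batches
    print(next(it))

try:
    next(it)                               # the iterator is exhausted now
except StopIteration:
    print("end of dataset reached")

it = iter(loader)                          # a fresh iterator starts over...
print(next(it))                            # ...with a new random order (shuffle=True)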
