How does 'enumerate(trainloader, 0)' work?

I was going through the CIFAR-10 tutorial and saw:

for epoch in range(2):  # loop over the dataset multiple times

    running_loss = 0.0
    for i, data in enumerate(trainloader, 0):
        # get the inputs
        inputs, labels = data

        # (older versions of the tutorial wrapped these in Variable; since
        # PyTorch 0.4 tensors track gradients directly, so no wrapping is needed)

        # zero the parameter gradients
        optimizer.zero_grad()

        # forward + backward + optimize
        outputs = net(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        # print statistics
        running_loss += loss.item()  # loss.data[0] fails on 0-dim tensors in newer PyTorch
        if i % 2000 == 1999:    # print every 2000 mini-batches
            print('[%d, %5d] loss: %.3f' %
                  (epoch + 1, i + 1, running_loss / 2000))
            running_loss = 0.0

I was wondering how enumerate(trainloader, 0) works. I want to make sure I am iterating over the data set correctly, but the code seems to be doing some odd things. For example:

  1. It resets running_loss every 2000 mini-batches, even though the batch size is 4. That means it is reset to zero every 2000 × 4 = 8,000 samples rather than once per epoch (an epoch is 50,000 samples, and 8,000 does not divide 50,000).

Perhaps I have just misunderstood it (and 2000 is an arbitrary choice), because I expected running_loss to be reset every floor(number of data points / batch size) mini-batches, i.e. once per pass over the data set.
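To put numbers on this: CIFAR-10 has 50,000 training images, so with a batch size of 4 one epoch is 12,500 mini-batches. The short calculation below (plain Python, no PyTorch needed) lists the 0-based mini-batch indices i at which the tutorial's `i % 2000 == 1999` check fires within one epoch, and how many mini-batches come after the last report.

```python
import math

num_samples = 50_000  # CIFAR-10 training set size
batch_size = 4        # batch size used in the tutorial
print_every = 2000    # the tutorial reports every 2000 mini-batches

# Number of mini-batches per epoch (a final partial batch would count too, hence ceil)
batches_per_epoch = math.ceil(num_samples / batch_size)

# Values of i at which `i % 2000 == 1999` is true during one epoch
report_indices = [i for i in range(batches_per_epoch)
                  if i % print_every == print_every - 1]

# Mini-batches after the last report, which never appear in a printed average
unreported = batches_per_epoch - (report_indices[-1] + 1)

print(batches_per_epoch)  # 12500
print(report_indices)     # [1999, 3999, 5999, 7999, 9999, 11999]
print(unreported)         # 500
```

So the loss is printed six times per epoch, and the last 500 mini-batches of each epoch are added to running_loss but discarded when it is reset at the start of the next epoch.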

Perhaps what I need is a running total, dividing by the total number of data points only at the very end? Also, enumerate(trainloader, 0) does restart the way it samples the data set each epoch, right? How does it do that? Does it call .next() internally and start again once a whole cycle is done?
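A per-epoch average along those lines could look like the sketch below. It assumes the criterion returns the mean loss over the batch (the default `reduction='mean'` for `nn.CrossEntropyLoss`), so each batch's loss is multiplied by its batch size before dividing by the total sample count; this keeps the result correct even if the last batch is smaller. To stay runnable without PyTorch, the hypothetical helper `epoch_average_loss` takes plain lists; in a real training loop the values would come from `loss.item()` and `inputs.size(0)`.

```python
def epoch_average_loss(batch_losses, batch_sizes):
    """Average per-sample loss over one epoch (hypothetical helper).

    batch_losses: per-batch mean losses (e.g. loss.item() in a PyTorch loop)
    batch_sizes:  number of samples in each batch (e.g. inputs.size(0))
    """
    total_loss = 0.0
    total_samples = 0
    for loss, n in zip(batch_losses, batch_sizes):
        # Undo the per-batch mean so a smaller final batch is not over-weighted
        total_loss += loss * n
        total_samples += n
    return total_loss / total_samples


# Two full batches of 4 samples and a final partial batch of 2
avg = epoch_average_loss([1.0, 2.0, 4.0], [4, 4, 2])
print(avg)  # 2.0
```

A plain mean of the three batch losses would give about 2.33 here, because it counts the 2-sample batch as heavily as the 4-sample ones.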

In addition, if the batch size doesn't divide the data set size, how is the last batch handled?


I assume an epoch of 50,000 samples takes a long time to run. Calculating and printing the average loss over groups of 2000 batches is simply a way to get feedback on training progress more often than once per epoch. Resetting running_loss to zero every now and then has no effect on training, because running_loss is only used for printing; it never feeds into the gradients or the optimizer.
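The reporting logic can be separated from training entirely, which makes it clear that running_loss is only bookkeeping. The sketch below, using a hypothetical helper `windowed_averages`, reproduces the tutorial's pattern on a plain list of per-batch losses: add each loss to a running total, and every `window` batches record the average and reset the total. Batches left over after the last full window are never reported, just as in the tutorial.

```python
def windowed_averages(batch_losses, window):
    """Mimic the tutorial's running_loss reporting (hypothetical helper).

    Returns (batch_number, average) pairs, one per full window of batches.
    """
    reports = []
    running_loss = 0.0
    for i, loss in enumerate(batch_losses):
        running_loss += loss
        if i % window == window - 1:  # same test as `i % 2000 == 1999`
            reports.append((i + 1, running_loss / window))
            running_loss = 0.0        # the reset only affects what gets reported
    return reports


reports = windowed_averages([1.0, 3.0, 2.0, 4.0, 5.0], window=2)
print(reports)  # [(2, 2.0), (4, 3.0)]
```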

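for i, data in enumerate(trainloader, 0): restarts the trainloader iterator on each epoch. That is how Python iterators work. Take the simpler loop for data in trainloader:. Python starts by calling trainloader.__iter__() to set up an iterator; this returns an object with a __next__() method (called next() in Python 2). Python then calls that method to get the first and each subsequent value used by the for loop, and the loop ends when it raises StopIteration. Because the outer epoch loop executes the inner for statement again, __iter__() is called again and iteration starts over from the first batch (with a new random order if the DataLoader was created with shuffle=True).

enumerate just wraps the trainloader iterator so that it returns a counter along with the value at each iteration. Its second argument, 0, is the starting value of that counter (which is also the default), so i runs 0, 1, 2, …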
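The protocol described above can be seen with a small stand-in for a DataLoader. `ToyLoader` below is a hypothetical class, not PyTorch's implementation: its `__iter__` returns a fresh generator every time it is called, which is why each epoch starts again from the first batch. The hand-written loop shows roughly what `for i, data in enumerate(loader, 0):` does: call `iter()`, then `next()` until `StopIteration`, pairing each value with a counter that starts at the second argument.

```python
class ToyLoader:
    """Hypothetical stand-in for a DataLoader: yields consecutive slices of a list."""

    def __init__(self, data, batch_size):
        self.data = data
        self.batch_size = batch_size

    def __iter__(self):
        # A new generator is created on every call, so each epoch starts over
        for start in range(0, len(self.data), self.batch_size):
            yield self.data[start:start + self.batch_size]


loader = ToyLoader(list(range(10)), batch_size=4)

# What `for i, data in enumerate(loader, 0):` does, spelled out by hand
epochs = []
for epoch in range(2):
    seen = []
    it = iter(loader)  # calls loader.__iter__(): a fresh iterator per epoch
    i = 0              # enumerate's second argument: the counter's start value
    while True:
        try:
            data = next(it)  # calls it.__next__()
        except StopIteration:
            break            # iterator exhausted: this is where the for loop ends
        seen.append((i, data))
        i += 1
    epochs.append(seen)

print(epochs[0])  # [(0, [0, 1, 2, 3]), (1, [4, 5, 6, 7]), (2, [8, 9])]
```

Both epochs produce the same sequence, and it matches `list(enumerate(loader, 0))`. The last batch here also holds only two items, because 4 does not divide 10.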

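If the batch size doesn't divide the data set size, the last batch will simply be smaller than the others. This is DataLoader's default behavior (drop_last=False); passing drop_last=True discards the incomplete last batch instead.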
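The resulting batch sizes can be worked out directly. The helper below, `batch_sizes`, is a hypothetical function that models DataLoader's behavior for a dataset of known length with the default sampler; it is not PyTorch code. With `drop_last=False` the final batch holds the remainder, and with `drop_last=True` the incomplete batch is dropped.

```python
def batch_sizes(num_samples, batch_size, drop_last=False):
    """Sizes of the batches a DataLoader-style loader would yield (a model, not PyTorch code)."""
    full, remainder = divmod(num_samples, batch_size)
    sizes = [batch_size] * full
    if remainder and not drop_last:
        sizes.append(remainder)  # last, smaller batch
    return sizes


print(batch_sizes(10, 4))                  # [4, 4, 2]
print(batch_sizes(10, 4, drop_last=True))  # [4, 4]
print(len(batch_sizes(50_000, 4)))         # 12500: 4 divides 50,000 evenly
```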
