I was going through the cifar10 tutorial and saw:
for epoch in range(2):  # loop over the dataset multiple times
    running_loss = 0.0
    for i, data in enumerate(trainloader, 0):
        # get the inputs
        inputs, labels = data
        # wrap them in Variable
        inputs, labels = Variable(inputs), Variable(labels)
        # zero the parameter gradients
        optimizer.zero_grad()
        # forward + backward + optimize
        outputs = net(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        # print statistics
        running_loss += loss.data
        if i % 2000 == 1999:  # print every 2000 mini-batches
            print('[%d, %5d] loss: %.3f' %
                  (epoch + 1, i + 1, running_loss / 2000))
            running_loss = 0.0
and was wondering: how does enumerate(trainloader, 0) work? I wanted to make sure I was going through the data set correctly and all that jazz, but the code seems to be doing funny things. For example:
- It resets running_loss every 2000 mini-batches even though the batch size is 4, so it resets every 2000 batches instead of once per epoch (2000 batches * 4 samples is only 8,000, not 50,000).
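For concreteness, with 50,000 training images and a batch size of 4, the loop sees 12,500 mini-batches per epoch, so the every-2000 print really doesn't line up with epoch boundaries; a quick check:

```python
# With 50,000 CIFAR10 training images and batch_size=4:
batches_per_epoch = 50_000 // 4   # 12,500 mini-batches per epoch

# running_loss is printed and reset at i = 1999, 3999, ..., i.e. every 2000 batches
print_points = [i for i in range(batches_per_epoch) if i % 2000 == 1999]
print(len(print_points))  # 6 prints per epoch

# loss from the remaining batches carries over into the next epoch's first print
leftover = batches_per_epoch - len(print_points) * 2000
print(leftover)  # 500 batches left over
```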
Perhaps it's just me that misunderstood it, because I expected
floor(# data points / batch size) to be the number of updates per pass through the data set.
Perhaps what I need to do is just keep a running sum and only at the very end divide by the total number of batches? Also, enumerate(trainloader, 0) does restart the way it samples the data set each epoch, right? Or how is it doing that? Does it call .next() internally and start again once a whole cycle is done?
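The running-average idea can be sketched like this (a sketch, not the tutorial's code; the losses below are hypothetical floats standing in for what criterion(outputs, labels) returns per batch):

```python
def epoch_average(batch_losses):
    """Average the per-batch losses accumulated over one epoch.

    Each element is assumed to be the mean loss over its mini-batch,
    so averaging over batches approximates the per-sample average
    when all batches have the same size.
    """
    return sum(batch_losses) / len(batch_losses)

# hypothetical per-batch losses collected during one epoch
losses = [2.3, 2.1, 1.9, 1.8]
print(epoch_average(losses))  # 2.025
```

Dividing by the number of batches (rather than data points) is the natural choice here, since each batch loss is already a per-sample mean within that batch.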
In addition, if the batch size doesn't divide the data set size, how is the last batch handled?
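To the last two questions: each new `for i, data in enumerate(trainloader, 0)` loop builds a fresh iterator over the loader, so every epoch starts a new (re-shuffled, if shuffle=True) pass; and by default DataLoader yields a smaller final batch when the batch size doesn't divide the dataset size (its drop_last argument discards it instead). A pure-Python toy stand-in illustrating that behavior, not the real DataLoader:

```python
class ToyLoader:
    """Minimal stand-in mimicking how a DataLoader batches a dataset."""

    def __init__(self, dataset, batch_size, drop_last=False):
        self.dataset = dataset
        self.batch_size = batch_size
        self.drop_last = drop_last

    def __iter__(self):
        # A fresh generator is created on every for-loop over the loader,
        # which is why each epoch restarts from the beginning.
        for start in range(0, len(self.dataset), self.batch_size):
            batch = self.dataset[start:start + self.batch_size]
            if len(batch) < self.batch_size and self.drop_last:
                return  # discard the short final batch
            yield batch

loader = ToyLoader(list(range(10)), batch_size=4)
for epoch in range(2):  # two epochs: iteration restarts each time
    batches = [b for i, b in enumerate(loader, 0)]
    print(batches)  # [[0, 1, 2, 3], [4, 5, 6, 7], [8, 9]] both epochs
```

With drop_last=True the trailing [8, 9] batch would be dropped, so every batch the model sees has exactly batch_size samples.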