I was going through the cifar10 tutorial and saw:
for epoch in range(2):  # loop over the dataset multiple times
    running_loss = 0.0
    for i, data in enumerate(trainloader, 0):
        # get the inputs
        inputs, labels = data
        # wrap them in Variable
        inputs, labels = Variable(inputs), Variable(labels)
        # zero the parameter gradients
        optimizer.zero_grad()
        # forward + backward + optimize
        outputs = net(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        # print statistics
        running_loss += loss.data
        if i % 2000 == 1999:  # print every 2000 mini-batches
            print('[%d, %5d] loss: %.3f' %
                  (epoch + 1, i + 1, running_loss / 2000))
            running_loss = 0.0
and was wondering: how does enumerate(trainloader, 0) work? I wanted to make sure I was going through the data set correctly and all that jazz, but the code seems to be doing funny things. For example:
- It resets running_loss every 2000 mini-batches even though the batch size is 4, so it resets every 2000 batches instead of once per epoch (2000 batches * 4 samples is only 8,000, not 50,000).
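For concreteness, with 50,000 training images and a batch size of 4, the loop sees 12,500 mini-batches per epoch, so the every-2000 print really doesn't line up with epoch boundaries; a quick check:

```python
# With 50,000 CIFAR10 training images and batch_size=4:
batches_per_epoch = 50_000 // 4   # 12,500 mini-batches per epoch

# running_loss is printed and reset at i = 1999, 3999, ..., i.e. every 2000 batches
print_points = [i for i in range(batches_per_epoch) if i % 2000 == 1999]
print(len(print_points))  # 6 prints per epoch

# loss from the remaining batches carries over into the next epoch's first print
leftover = batches_per_epoch - len(print_points) * 2000
print(leftover)  # 500 batches left over
```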
Perhaps it's just me that misunderstood it, because I expected
floor(# data points / batch size) to be the number of updates per pass through the data set.
Perhaps what I need to do is just keep a running sum and only at the very end divide by the total number of batches? Also, enumerate(trainloader, 0) does restart the way it samples the data set each epoch, right? Or how is it doing that? Does it call .next() internally and start again once a whole cycle is done?
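The running-average idea can be sketched like this (a sketch, not the tutorial's code; the losses below are hypothetical floats standing in for what criterion(outputs, labels) returns per batch):

```python
def epoch_average(batch_losses):
    """Average the per-batch losses accumulated over one epoch.

    Each element is assumed to be the mean loss over its mini-batch,
    so averaging over batches approximates the per-sample average
    when all batches have the same size.
    """
    return sum(batch_losses) / len(batch_losses)

# hypothetical per-batch losses collected during one epoch
losses = [2.3, 2.1, 1.9, 1.8]
print(epoch_average(losses))  # 2.025
```

Dividing by the number of batches (rather than data points) is the natural choice here, since each batch loss is already a per-sample mean within that batch.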
In addition, if the batch size doesn't divide the data set size, how is the last batch handled?
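To the last two questions: each new `for i, data in enumerate(trainloader, 0)` loop builds a fresh iterator over the loader, so every epoch starts a new (re-shuffled, if shuffle=True) pass; and by default DataLoader yields a smaller final batch when the batch size doesn't divide the dataset size (its drop_last argument discards it instead). A pure-Python toy stand-in illustrating that behavior, not the real DataLoader:

```python
class ToyLoader:
    """Minimal stand-in mimicking how a DataLoader batches a dataset."""

    def __init__(self, dataset, batch_size, drop_last=False):
        self.dataset = dataset
        self.batch_size = batch_size
        self.drop_last = drop_last

    def __iter__(self):
        # A fresh generator is created on every for-loop over the loader,
        # which is why each epoch restarts from the beginning.
        for start in range(0, len(self.dataset), self.batch_size):
            batch = self.dataset[start:start + self.batch_size]
            if len(batch) < self.batch_size and self.drop_last:
                return  # discard the short final batch
            yield batch

loader = ToyLoader(list(range(10)), batch_size=4)
for epoch in range(2):  # two epochs: iteration restarts each time
    batches = [b for i, b in enumerate(loader, 0)]
    print(batches)  # [[0, 1, 2, 3], [4, 5, 6, 7], [8, 9]] both epochs
```

With drop_last=True the trailing [8, 9] batch would be dropped, so every batch the model sees has exactly batch_size samples.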