Plotting loss curve

You could use the ImageNet example or the following manual approach:

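import matplotlib.pyplot as plt

loss_values = []  # one average loss per epoch
for epoch in range(num_epochs):
    running_loss = 0.0
    for i, data in enumerate(trainloader, 0):
        images, labels = data
        # ... forward pass, loss computation, backward pass and optimizer step go here ...
        # loss.item() is the batch mean, so weight it by the batch size before accumulating
        running_loss += loss.item() * images.size(0)

    # Per-sample average loss over the whole epoch
    loss_values.append(running_loss / len(train_dataset))

plt.plot(loss_values)
plt.show()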

This code would plot a single loss value for each epoch. Would that work?
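To illustrate why the batch loss is multiplied by the batch size before dividing by the dataset size, here is a minimal pure-Python sketch. The helper name `epoch_average_loss` and the numbers are made up for illustration, and it assumes the loss function returns the mean over each batch (the default reduction for most PyTorch losses).

```python
def epoch_average_loss(batch_mean_losses, batch_sizes):
    """Per-sample average loss for one epoch, given per-batch mean losses."""
    running_loss = 0.0
    for batch_loss, batch_size in zip(batch_mean_losses, batch_sizes):
        # Turn each batch mean back into a batch sum before accumulating
        running_loss += batch_loss * batch_size
    return running_loss / sum(batch_sizes)


# Hypothetical epoch: 10 samples split into batches of 4, 4 and 2
batch_mean_losses = [0.5, 0.25, 1.0]
batch_sizes = [4, 4, 2]

weighted = epoch_average_loss(batch_mean_losses, batch_sizes)
naive = sum(batch_mean_losses) / len(batch_mean_losses)

print(f"weighted per-sample average: {weighted}")
print(f"naive average of batch means: {naive}")
```

The two values differ only when the batches are not all the same size, which typically happens with the last batch of an epoch. In that case the naive average gives the small final batch too much weight.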
