Periodic oscillations in loss function

I’m training a CNN on ImageNet, and I’m seeing odd, periodic oscillations in my loss function. The overall trend looks right, but there are clear oscillations as the curve flattens out. I’m logging the loss every 100 mini-batches; with a batch size of 128 and a training set of ~100k images, that works out to roughly 781 mini-batches, or about 8 logged points, per epoch. The periodicity I see is that the loss drops suddenly at the start of each epoch and then climbs steadily until the next epoch begins.
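As a quick sanity check on that arithmetic (the ~100k figure is approximate):

num_images = 100_000    # approximate training-set size
batch_size = 128
log_every = 100         # mini-batches per logged point
batches_per_epoch = num_images / batch_size    # ~781
print(batches_per_epoch / log_every)           # ~7.8 logged points per epoch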

My training code is below. I’m resetting running_loss to zero at the beginning of each epoch and at the end of every set of 100 mini-batches. Does anyone have any perspective here? I don’t know whether I’m doing something wrong in the way I’m computing/collecting the loss, or whether there’s something more fundamentally wrong with my model.

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader

alexnet = AlexNetPyTorch(NUM_CLASSES) # My model class
alexnet = alexnet.to(device)
loss_fn = nn.CrossEntropyLoss()
optimizer = optim.Adam(params=alexnet.parameters(), lr=3.0e-4, weight_decay=5.0e-4)

train_data = ImageDataset(train_file_list, transform=transform)        
train_data_loader = DataLoader(train_data, batch_size=128, shuffle=True, num_workers=4)

history = {'loss': [], 'val_loss': [], 'accuracy': [], 'val_accuracy': []}
step_increment = 100
for epoch in range(30):
    
    running_loss = 0.0 # accumulates training loss; reset every 100 mini-batches (and at each epoch start)
    
    for step, data in enumerate(train_data_loader):
        alexnet.train() # back to train mode (eval() is set during the logging step below)
        X_train, y_train = data
        X_train = X_train.to(device)
        y_train = y_train.to(device)

        # forward + backward + optimize
        y_pred = alexnet(X_train)
        loss = loss_fn(y_pred, y_train)
        
        # zero the parameter gradients
        optimizer.zero_grad()
        
        loss.backward()
        optimizer.step()

        # print statistics
        running_loss += loss.item()
        
        if step % step_increment == step_increment-1:    # print every 100 mini-batches
        
            # Log metrics: accuracy on the most recent training batch,
            # plus loss/accuracy on the held-out validation batch
            with torch.no_grad():
                alexnet.eval()

                _, preds = torch.max(y_pred, 1) # y_pred is from the last training mini-batch
                train_accuracy = torch.sum(preds == y_train).item()/len(y_train)

                X_val, y_val = validation_data
                X_val = X_val.to(device)
                y_val = y_val.to(device)
                y_pred = alexnet(X_val)
                val_loss = loss_fn(y_pred, y_val).item()

                _, preds = torch.max(y_pred,1)
                val_accuracy = torch.sum(preds == y_val).item()/len(y_val)

                history['loss'].append(running_loss/step_increment)
                history['accuracy'].append(train_accuracy)
                history['val_accuracy'].append(val_accuracy)
                history['val_loss'].append(val_loss)

                print('[%d, %5d] loss: %.3f acc: %.3f val loss: %.3f val acc: %.3f' %
                      (epoch + 1, step+1, running_loss/step_increment, train_accuracy, val_loss, val_accuracy))
                
                running_loss = 0.0
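For completeness: validation_data isn’t defined in the snippet above. It’s a single fixed batch drawn once before training starts, along these lines (the file-list name and batch size here are placeholders):

val_data = ImageDataset(val_file_list, transform=transform)      # held-out files
val_loader = DataLoader(val_data, batch_size=512, shuffle=False)
validation_data = next(iter(val_loader)) # one (X_val, y_val) batch, reused at every logging step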

Are you shuffling the data?
If not, this might explain these oscillations, e.g. if “harder” samples are near the end.
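A quick way to check is to compare the labels of the first batch across two fresh passes over the loader; with shuffle=True each pass reshuffles, so the orderings should differ:

for pass_idx in range(2):
    _, first_labels = next(iter(train_data_loader)) # fresh iterator => fresh shuffle
    print(pass_idx, first_labels[:8])               # should differ between passes if shuffling is on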

Yes, I am shuffling the data. My data loader is defined like so:

train_data = ImageDataset(train_file_list, transform=transform)        
train_data_loader = DataLoader(train_data, batch_size=128, shuffle=True, num_workers=4)

Any other ideas out there? I’d love a second opinion.