Is my batch accumulation implementation correct?

Hi, I’d like to know if my code for training a model with batch accumulation is correct. Especially the part regarding the loss calculation because I’m not so sure this is the right way to do this.
Here’s my code:

def train (start_epochs, n_epochs, best_acc, train_generator, val_generator, model, optimizer, criterion, checkpoint_path, best_model_path):
  #num_epochs = 25
  since = time.time()

  #best_model_wts = copy.deepcopy(model.state_dict())
  #best_acc = 0.0
  train_loss = []
  val_loss = []
  train_acc = []
  val_acc = []

  batch_accumulation = 8

  for epoch in tqdm(range(start_epochs, n_epochs+1)):

    running_train_loss = 0.0
    running_val_loss = 0.0

    running_train_corrects = 0
    running_val_corrects = 0

    optimizer.zero_grad
    #Training
    model.train()
    for i, (faces, labels) in tqdm(enumerate(train_generator)):
      
      faces = faces.to(device)
      labels = labels.to(device)

      #forward
      outputs = model(faces)

      #predictions of the model determined using the torch.max() function, which returns the index of the maximum value in a tensor.
      _, preds = torch.max(outputs[1], 1)

      #pass the model outputs and the true image labels to the loss function
      loss = criterion(outputs[1], labels)
      #loss = loss / batch_accumulation
      running_train_loss += loss.item()
      # Backprop and Adam optimisation
      loss.backward()
      # Track the accuracy and loss
      running_train_corrects += torch.sum(preds == labels.data)

      if (i+1)% batch_accumulation == 0:
        optimizer.step()
        optimizer.zero_grad # zero the gradient buffers 
       
    # calculate average losses and accuracy  
    epoch_train_loss = running_train_loss / len(train_generator.dataset)
    epoch_train_acc = ((running_train_corrects.double() / len(train_generator.dataset)) * 100)
    train_loss.append(epoch_train_loss)
    train_acc.append(epoch_train_acc)

    print('Train Loss: {:.4f} Train Acc: {:.2f}%'.format(epoch_train_loss, epoch_train_acc))

    #Validation
    with torch.set_grad_enabled(False):
      model.eval()
      for i , (faces_val, labels_val) in tqdm(enumerate(val_generator)):

        faces_val = faces_val.to(device)
        labels_val = labels_val.to(device)
        
        if (i+1)% batch_accumulation == 0:

          outputs_val = model(faces_val)

          _, preds_val = torch.max(outputs_val[1], 1)
          loss_val = criterion(outputs_val[1], labels_val)

          running_val_loss += loss_val.item() 
          #running_val_loss = running_val_loss +((1 /(i+1)) * (loss.item() - running_val_loss))
          running_val_corrects += torch.sum(preds_val == labels_val.data)

    # calculate average losses and accuracy 
    epoch_val_loss = running_val_loss / len(validation_generator.dataset)
    epoch_val_acc = (running_val_corrects.double() / len(validation_generator.dataset)) * 100
    val_loss.append(epoch_val_loss)
    val_acc.append(epoch_val_acc)

    print('Validation Loss: {:.4f} Validation Acc: {:.2f}%'.format(epoch_val_loss, epoch_val_acc))
    
    # create checkpoint variable and add important data
    checkpoint = {
        'epoch': epoch + 1,
        'valid_loss_min': epoch_val_loss,
        'valid_accuracy_best': epoch_val_acc,
        'state_dict': model.state_dict(),
        'optimizer': optimizer.state_dict(),
        }
          
    # save checkpoint
    save_ckp(checkpoint, False, checkpoint_path, best_model_path)

    if(epoch_val_acc > best_acc):
      save_ckp(checkpoint, True, checkpoint_path, best_model_path)
      best_acc = epoch_val_acc
      #best_model_wts = copy.deepcopy(model.state_dict())
    
    save_stat(train_loss, train_loss)
    save_stat(val_loss, val_loss)
    save_stat(train_acc, train_acc)
    train_acc(val_acc, val_acc)
      
  time_elapsed = time.time() - since
  print('Training complete in {:.0f}m {:.0f}s'.format(
          time_elapsed // 60, time_elapsed % 60))
  print('Best val Acc: {:4f}'.format(best_acc))

  return model, train_loss, val_loss, train_acc, val_acc

I got strange epoch train results (like 456.890) and I’m note sure about the if statement in the validation part.
Any help would be great, thanks!

Hi,

optimizer.zero_grad is a function, you need to call it: optimizer.zero_grad()!
Otherwise, your code looks ok!

1 Like

Thank you. Still I have problems on accuracy on validation set. I got around 25% at each epoch while the training accuracy increases. I set Shuffle = False on val test and Shuffle = True on train set. This is something wrong in how I call model.eval() ?

Hi,

I don’t think there is anything wrong with the way you use eval no.
Note that if you use batchnorm like layers, you can have this kind of behavior where eval behaves differently than training (you check look for batchnorm eval on this forum for more info).

Hi,
I have huge gap between train and validation accuracy. (85% train vs 25% val). I looked up for batchnorm problems and I made some changes but still I didn’t figure out why I got these results?
Can you help me?