Getting different loss values with gradient accumulation but same effective batch size

I ran two experiments with no other code changes: one with batch size 5 and accumulation_steps = 2, and another with batch size 10 and accumulation_steps = 1. I expected to see similar loss values on wandb for both, but they were quite different. Both runs use the same seed and learning rate (no LR schedule), no dropout, deterministic set to true, and data shuffling turned off. My model also uses instance norm rather than batch norm. Here is the code, in case anyone can help:

import torch
import torch.nn as nn
from torch.optim import Adam
from torch.cuda.amp import GradScaler
from tqdm import tqdm
import wandb

optimizer = Adam(model.parameters(), lr=config.training.lr)
scaler = GradScaler()
counter = 0
total_loss = 0.0
accumulation_steps = 2
optimizer.zero_grad()
for batch in tqdm(train_loader):

    y = batch.to(device)
    x = y.sum(1)  # input = sum over dim 1 (the instrument dimension)
    if config.training.target_instrument is not None:
        i = config.training.instruments.index(config.training.target_instrument)
        y = y[:, i]
    with torch.cuda.amp.autocast():
        y_ = model(x)
        loss = nn.MSELoss()(y_, y)

    # gradients from consecutive backward() calls are summed until zero_grad()
    scaler.scale(loss).backward()
    if (counter + 1) % accumulation_steps == 0:
        scaler.step(optimizer)
        scaler.update()
        optimizer.zero_grad(set_to_none=True)
    total_loss += loss.item()
    counter += 1
    # logged every 100 loader iterations (micro-batches), not every 100 optimizer steps
    if counter % 100 == 0:
        average_loss = total_loss / 100
        wandb.log({"loss_100_steps": average_loss})
        total_loss = 0.0

Based on your code, you are not dividing the loss (or the gradients) by the number of accumulation steps, so the micro-batch gradients are summed rather than averaged. Disable amp for a quick test and print the .grad attributes of some parameters: in the accumulation_steps = 2 case the gradients should have a larger magnitude, which could explain the difference in your loss values.
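A quick CPU check of this point, with amp disabled as suggested. The toy linear model and random data are hypothetical stand-ins, not the original setup. The check compares the gradient of one batch of 10 against the gradients accumulated over two micro-batches of 5, with and without dividing each micro-batch loss by accumulation_steps. Because nn.MSELoss averages over the batch, summing two micro-batch gradients gives twice the full-batch gradient.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
# Hypothetical toy setup: a linear layer and random data, float64 for exact comparisons.
model = nn.Linear(4, 1).double()
x = torch.randn(10, 4, dtype=torch.float64)
y = torch.randn(10, 1, dtype=torch.float64)
loss_fn = nn.MSELoss()  # mean reduction over the batch
accumulation_steps = 2


def accumulated_grad(divide: bool) -> torch.Tensor:
    """Accumulate weight gradients over two micro-batches of 5 samples each."""
    model.zero_grad()
    for xb, yb in zip(x.chunk(accumulation_steps), y.chunk(accumulation_steps)):
        loss = loss_fn(model(xb), yb)
        if divide:
            loss = loss / accumulation_steps
        loss.backward()  # .grad is summed across backward() calls
    return model.weight.grad.clone()


# Reference: one full batch of 10 samples.
model.zero_grad()
loss_fn(model(x), y).backward()
full_grad = model.weight.grad.clone()

summed = accumulated_grad(divide=False)
averaged = accumulated_grad(divide=True)

print(torch.allclose(summed, 2 * full_grad))  # True: summed grads are twice as large
print(torch.allclose(averaged, full_grad))    # True: dividing restores the full-batch grad
```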

Thank you @ptrblck

Do you know what code I can add or change to fix this, please?
I tried adding loss = loss / accumulation_steps, but the loss values are still quite different.

Yes, scaling the loss should work. If you get stuck, could you post a minimal and executable code snippet reproducing the unexpected divergence?
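Below is a minimal, executable sketch along those lines. It runs on CPU without amp, and the tiny linear model, random data and lr are hypothetical stand-ins. It trains on the same 40 samples twice: once with batch size 10 and accumulation_steps = 1, and once with batch size 5 and accumulation_steps = 2. Each micro-batch loss is divided by accumulation_steps before backward(), and one loss value is recorded per optimizer step.

Two details of the posted loop may also matter when comparing the wandb curves. First, if loss = loss / accumulation_steps is placed before total_loss += loss.item(), the logged values are divided as well. Second, "loss_100_steps" is logged every 100 loader iterations. That is 100 optimizer steps with accumulation_steps = 1 but only 50 with accumulation_steps = 2, so the two curves are not aligned step for step.

```python
import torch
import torch.nn as nn
from torch.optim import Adam
from torch.utils.data import DataLoader, TensorDataset


def train(batch_size: int, accumulation_steps: int, num_samples: int = 40) -> list[float]:
    """Return the mean training loss recorded at each optimizer step."""
    torch.manual_seed(0)  # same data and same initial weights for every run
    x = torch.randn(num_samples, 4, dtype=torch.float64)
    y = torch.randn(num_samples, 1, dtype=torch.float64)
    # shuffle=False so both runs see the samples in the same order
    loader = DataLoader(TensorDataset(x, y), batch_size=batch_size, shuffle=False)
    model = nn.Linear(4, 1).double()  # hypothetical stand-in for the real model
    optimizer = Adam(model.parameters(), lr=1e-2)
    loss_fn = nn.MSELoss()

    step_losses, running = [], 0.0
    optimizer.zero_grad(set_to_none=True)
    for counter, (xb, yb) in enumerate(loader):
        loss = loss_fn(model(xb), yb)
        (loss / accumulation_steps).backward()  # average, not sum, the micro-batch grads
        running += loss.item()                  # log the *unscaled* loss
        if (counter + 1) % accumulation_steps == 0:
            optimizer.step()
            optimizer.zero_grad(set_to_none=True)
            # one entry per optimizer step, i.e. per effective batch of samples
            step_losses.append(running / accumulation_steps)
            running = 0.0
    return step_losses


full = train(batch_size=10, accumulation_steps=1)
accum = train(batch_size=5, accumulation_steps=2)
print(len(full), len(accum))  # 4 4
print(all(abs(a - b) < 1e-9 for a, b in zip(full, accum)))  # True
```

With amp, the gradient accumulation example in the PyTorch AMP docs follows the same pattern. The division happens inside the autocast block, and scaler.scale(loss).backward(), scaler.step(optimizer) and scaler.update() stay where they are.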

Hi,
were you able to find a solution? I am also facing this issue, even with the loss normalized.
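One more thing may be worth checking, offered as a side observation rather than a confirmed diagnosis. Adam divides its first-moment estimate by the square root of its second-moment estimate. Multiplying every gradient by a constant therefore leaves the update almost unchanged; only eps breaks the invariance. With Adam, a missing division by accumulation_steps should change training much less than it would with plain SGD.

The sketch below illustrates this. It runs a few steps from identical starting points with a fixed gradient and with twice that gradient. The gradient values and lr are hypothetical.

```python
import torch
from torch.optim import SGD, Adam

# Hypothetical fixed gradient, applied unchanged at every step.
GRAD = torch.tensor([0.5, -1.0, 2.0, -0.25, 1.5], dtype=torch.float64)


def params_after(opt_cls, grad_scale: float, steps: int = 3) -> torch.Tensor:
    """Run a few optimizer steps from zero with the gradient multiplied by grad_scale."""
    p = torch.nn.Parameter(torch.zeros(5, dtype=torch.float64))
    opt = opt_cls([p], lr=0.1)
    for _ in range(steps):
        p.grad = GRAD * grad_scale  # 2.0 mimics summing 2 un-normalized micro-batches
        opt.step()
    return p.detach().clone()


adam_same = torch.allclose(params_after(Adam, 1.0), params_after(Adam, 2.0))
sgd_same = torch.allclose(params_after(SGD, 1.0), params_after(SGD, 2.0))
print(adam_same)  # True: Adam's update is (almost) invariant to gradient scale
print(sgd_same)   # False: SGD's step doubles with the gradient
```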