Gradient Accumulation in PyTorch

Hello, I’m implementing gradient accumulation to work around memory limits when training with large batch sizes. My loss function depends on the average predicted probabilities over the entire batch rather than on each sample independently. Here’s my current code:

#----------------------- for grad accumulation --------------------------------------
global_step = 0   # number of optimizer updates performed so far
probs_avg = None  # running average of predicted class probabilities
# -----------------------------------------------------------------------------------
for step, (x, y, idx) in enumerate(metric_logger.log_every(data_loader_u, print_freq, header)):
    # ----------------------- unlabelled --------------------------------------------
    x = x.to(device, non_blocking=True)
    y = y.to(device, non_blocking=True)
    with torch.cuda.amp.autocast():
        logits = model(x)
        probs_all = F.softmax(logits, dim=-1)
        probs_batch_avg = probs_all.mean(0)  # average prediction probability over this minibatch (this GPU only)
        # moving average; the detached history means gradients flow only through the current minibatch
        if probs_avg is None:
            probs_avg = probs_batch_avg
        else:
            probs_avg = 0.5 * (probs_avg.detach() + probs_batch_avg)
        loss_x = -(torch.log(probs_avg)).mean() / 0.5
        
        #----------------------------------------------------------
        loss = (loss_x) / accumulation_steps
    ## ---------------------- for grad accumulation ------------------------
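    loss.backward()  # gradients add up in .grad until zero_grad() is called
    if (step + 1) % accumulation_steps == 0:
        optimizer.step()       # apply the accumulated gradients
        optimizer.zero_grad()  # then clear them for the next accumulation window
        global_step += 1
    torch.cuda.synchronize()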
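Because loss_x takes the log of a batch-averaged probability, it does not split into a sum of independent per-minibatch terms the way a per-sample loss such as cross-entropy does. The minimal sketch below (hypothetical random logits on CPU, with the / 0.5 scale dropped) compares the loss on one batch of 8 samples with the mean of the losses on four minibatches of 2. Since -log is convex, Jensen's inequality says the minibatch mean can never be smaller than the full-batch value.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)

# Hypothetical logits for a "large" batch of 8 samples over 4 classes.
logits = torch.randn(8, 4)

def batch_prior_loss(logits: torch.Tensor) -> torch.Tensor:
    # Average softmax probabilities over the batch dimension, then take the
    # mean negative log over classes (same form as loss_x, without the / 0.5).
    probs_avg = F.softmax(logits, dim=-1).mean(0)
    return -torch.log(probs_avg).mean()

# Loss computed on the whole batch at once.
full_batch = batch_prior_loss(logits)

# Loss computed on 4 equal minibatches of 2, then averaged,
# which is what naive gradient accumulation effectively optimizes.
minibatch_mean = torch.stack([batch_prior_loss(chunk) for chunk in logits.chunk(4)]).mean()

print(f"full batch: {full_batch.item():.4f}  mean of minibatches: {minibatch_mean.item():.4f}")
```

The gap shrinks as minibatches get larger, but for small minibatches the accumulated objective is a different (and biased) loss.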

However, I’m not sure whether probs_avg is computed over the whole accumulated batch or only over each minibatch.
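The moving average in the code above detaches earlier minibatches, so each backward() only propagates gradients through the current minibatch's probs_batch_avg; with a loss that is nonlinear in the batch average, this is not the same as the large-batch gradient. If the exact large-batch gradient is the goal, one option (a two-pass technique, not something taken from the original code) is to compute the full-batch average probabilities without gradients, take dL/d(probs_avg) once, then recompute each minibatch and backpropagate its contribution via the chain rule. The sketch below is an assumption-laden illustration: a hypothetical nn.Linear model and random inputs on CPU, no autocast, and the / 0.5 scale dropped. It compares both naive per-minibatch accumulation and the two-pass result against the true full-batch gradient.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)

model = nn.Linear(16, 4)          # hypothetical toy model
x_all = torch.randn(32, 16)       # one "large" batch (random, illustrative)
minibatches = x_all.chunk(4)      # processed 8 samples at a time
n_total = x_all.shape[0]

def prior_loss(probs_avg: torch.Tensor) -> torch.Tensor:
    # Same form as loss_x, without the / 0.5 scale.
    return -torch.log(probs_avg).mean()

def grads() -> list:
    # Snapshot the current .grad of every parameter.
    return [p.grad.clone() for p in model.parameters()]

# Reference: gradient of the loss computed on the whole batch at once.
model.zero_grad()
prior_loss(F.softmax(model(x_all), dim=-1).mean(0)).backward()
ref_grads = grads()

# Naive accumulation: each minibatch uses only its own average probabilities.
model.zero_grad()
for mb in minibatches:
    (prior_loss(F.softmax(model(mb), dim=-1).mean(0)) / len(minibatches)).backward()
naive_grads = grads()

# Two-pass accumulation.
# Pass 1: full-batch average probabilities, without building a graph.
with torch.no_grad():
    probs_sum = sum(F.softmax(model(mb), dim=-1).sum(0) for mb in minibatches)
probs_avg = (probs_sum / n_total).requires_grad_()
(grad_probs_avg,) = torch.autograd.grad(prior_loss(probs_avg), probs_avg)  # dL/d(probs_avg)

# Pass 2: recompute each minibatch and backpropagate its share via the chain rule.
model.zero_grad()
for mb in minibatches:
    share = F.softmax(model(mb), dim=-1).sum(0) / n_total  # this minibatch's part of probs_avg
    (share * grad_probs_avg).sum().backward()               # grads accumulate across minibatches
two_pass_grads = grads()

def matches(a: list, b: list) -> bool:
    return all(torch.allclose(x, y, atol=1e-6) for x, y in zip(a, b))

print("naive matches full batch:   ", matches(naive_grads, ref_grads))
print("two-pass matches full batch:", matches(two_pass_grads, ref_grads))
```

The price is an extra forward pass per minibatch, and the parameters must not change between the two passes (no optimizer.step() in between). Dropout or other stochastic layers would need identical randomness in both passes. This computes the exact full-batch objective rather than reproducing the moving average.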