Gradient accumulation implementation for teacher model

Hello, I am trying to perform gradient accumulation in my code, but I don't know where in the training loop I should update the teacher model :slight_smile:

import math
import sys
from typing import Iterable

import torch
import torch.nn.functional as F


def train_one_epoch(model: torch.nn.Module, args, train_config,
                    data_loader: Iterable, optimizer: torch.optim.Optimizer, amp_autocast,
                    device: torch.device, epoch: int, start_steps=None,
                    lr_schedule_values=None, model_ema=None):
    print(f'--------- epoch : {epoch} ---------')
    model.train()
    #----------------------- for grad accumulation --------------------------------------
    accumulation_steps = 8
    global_step = 0
    # ----------------------- running statistics for logging ----------------------------
    total_samples = 0
    total_loss = 0.0
    correct_pred_ema = 0.0
    correct_pred_stu_unlabelled = 0.0
    pseudo_label_acc = 0.0
    total_conf = 0.0
    # -----------------------------------------------------------------------------------
    # start the epoch with clean gradients so nothing left over from a previous epoch
    # leaks into the first accumulation cycle
    optimizer.zero_grad()
    # metric_logger, print_freq and header are defined in the surrounding script (not shown)
    for step, ((inputs_u_w, inputs_u_s), targets_u) in enumerate(metric_logger.log_every(data_loader, print_freq, header)):
        # --------------- assign learning rate at the start of each accumulation cycle ---------------
        if step % accumulation_steps == 0:
            it = start_steps + global_step
            if lr_schedule_values is not None:
                for i, param_group in enumerate(optimizer.param_groups):
                    param_group["lr"] = lr_schedule_values[it] * param_group["lr_scale"]
            # warm up the EMA decay from its initial value towards its final value
            model_ema.decay = (train_config['model_ema_decay_init']
                               + (args.model_ema_decay - train_config['model_ema_decay_init'])
                               * min(1, it / train_config['warm_it']))
        # ------------------ Load data into GPU ---------------------------------
        inputs_u_w, inputs_u_s, targets_u = inputs_u_w.to(device), inputs_u_s.to(device), targets_u.to(device)
        #----------------------- For the record -----------------------------------
        total_samples += targets_u.shape[0]
        # -----------------------------------------------------------------------
        with torch.no_grad():
            # teacher (EMA) forward pass on the weakly augmented view
            logits_ema, feats_ema = model_ema.ema(inputs_u_w)
            #------------------ pseudo-labels based on the teacher prediction -----------
            probs_ema = F.softmax(logits_ema, dim=-1)
            score, pseudo_targets = torch.max(probs_ema, dim=-1)
            conf_mask = (score > train_config['conf_threshold'])
            #----------------------- For the record -----------------------------------
            correct_pred_ema += (pseudo_targets == targets_u).float().sum().item()
            if conf_mask.sum() > 0:
                pseudo_label_acc += (pseudo_targets[conf_mask] == targets_u[conf_mask]).float().sum().item()
                total_conf += conf_mask.float().sum().item()
        #---------------------------------------------------------------------
        with amp_autocast():
            # --------------- self-training: student forward on the strongly augmented view ---------------
            logits_u_s, _ = model(inputs_u_s)
            loss_u = F.cross_entropy(logits_u_s, pseudo_targets)
            #---------------- total loss, scaled so the accumulated gradient matches a full batch ----------
            loss = loss_u / accumulation_steps
            #----------------------- For the record -----------------------------------
            prob_u_s = F.softmax(logits_u_s.detach(), dim=-1)
            _, max_indices = torch.max(prob_u_s, dim=-1)
            correct_pred_stu_unlabelled += (max_indices == targets_u).float().sum().item()
        # ---------------------------------------------------------------------
        loss_value = loss.item()
        #----------------------- For the record -------------------------------
        total_loss += loss_value
        # ---------------------------------------------------------------------
        if not math.isfinite(loss_value):
            print(f"Loss is {loss_value}, stopping training")
            sys.exit(1)
        ## ---------------------- for grad accumulation ------------------------
        # gradients accumulate across micro-batches; only step/zero every `accumulation_steps` batches
        loss.backward()
        if (step + 1) % accumulation_steps == 0:
            global_step += 1
            optimizer.step()
            optimizer.zero_grad()
            # update the teacher (EMA) model once per effective batch, right after the optimizer step
            model_ema.update(model)
    print("Averaged stats:", metric_logger)
    #----------------------- epoch-level statistics -----------------------------------
    training_accuracy_stu_unlabelled = (correct_pred_stu_unlabelled / total_samples) * 100
    training_accuracy_ema            = (correct_pred_ema / total_samples) * 100
    accuracy_pseudo_label            = (pseudo_label_acc / max(total_conf, 1)) * 100  # guard: no confident samples
    avg_loss                         = total_loss / max(global_step, 1)               # guard: fewer steps than accumulation_steps
    # `bank` is defined elsewhere in the surrounding script (not shown here)
    return {k: meter.global_avg for k, meter in metric_logger.meters.items()}, \
           bank, \
           training_accuracy_stu_unlabelled, training_accuracy_ema, \
           accuracy_pseudo_label, avg_loss
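
For reference, this is the general pattern I believe is correct (a minimal sketch, not my actual code; `sketch_accumulation_loop`, `teacher_ema` and the plain cross-entropy loss are placeholders, and `teacher_ema` stands for any EMA wrapper that exposes an `update(model)` method): the teacher/EMA update sits next to `optimizer.step()`, so it runs once per effective batch rather than once per micro-batch. Is placing `model_ema.update(model)` there, as in my loop above, the right thing to do?

import torch
import torch.nn.functional as F

def sketch_accumulation_loop(model, teacher_ema, optimizer, data_loader, accumulation_steps):
    # Minimal sketch: `model` returns plain logits here, and `teacher_ema` is any
    # EMA wrapper exposing update(model); both are placeholders, not my real objects.
    optimizer.zero_grad()
    num_batches = len(data_loader)
    for step, (inputs, targets) in enumerate(data_loader):
        loss = F.cross_entropy(model(inputs), targets) / accumulation_steps
        loss.backward()                                           # gradients accumulate across micro-batches
        last_batch = (step + 1) == num_batches
        if (step + 1) % accumulation_steps == 0 or last_batch:    # also flush the leftover tail
            optimizer.step()
            optimizer.zero_grad()
            teacher_ema.update(model)                             # teacher follows every optimizer step

With this placement the teacher sees exactly one update per optimizer step, so the EMA decay warm-up stays in step with `global_step`.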