Does the model retain the same weights at the start of each epoch?

Hi, I have a question about the code snippet below. I’m trying to determine whether the model starts each epoch with the same weights. If so, wouldn’t it be necessary to carry over the weights from the last epoch instead of resetting them at the beginning of each epoch?

def train():
    """Train ``MyModel`` over ``args.start_epoch .. args.epochs - 1`` with periodic eval.

    The model is constructed exactly once, before the epoch loop.  Every call to
    ``train_one_epoch`` updates that same object in place, so each epoch starts
    from the weights the previous epoch finished with -- nothing here resets them.

    Relies on names defined elsewhere in the script (``args``, ``optimizer``,
    ``data_loader_train``/``data_loader_val``, ``loss_scaler``, ``log_writer``,
    ``model_ema``, ``utils``, ``evaluate``, ...).
    """
    model = MyModel()
    # No DDP wrapping happens in this function, so the "unwrapped" model handed
    # to the checkpoint helper is this same model object.
    model_without_ddp = model
    # Best validation top-1 accuracy seen so far; it must exist before the first
    # comparison below, otherwise Python raises UnboundLocalError.
    max_accuracy = 0.0

    for epoch in range(args.start_epoch, args.epochs):
        # NOTE(review): this call was truncated in the original paste; add any
        # further keyword arguments it needs (log_writer, lr_schedule_values, ...).
        train_stats = train_one_epoch(
            model, args, train_config,
            data_loader_train, optimizer, amp_autocast, device, epoch, loss_scaler,
            start_steps=epoch * num_training_steps_per_epoch,
        )

        if args.output_dir and utils.is_main_process() and (epoch + 1) % args.eval_freq == 0:
            if (epoch + 1) % args.save_ckpt_freq == 0 or epoch + 1 == args.epochs:
                # NOTE(review): helper name reconstructed from a truncated paste -- confirm.
                utils.save_model(
                    args=args, model=model, model_without_ddp=model_without_ddp, optimizer=optimizer,
                    loss_scaler=loss_scaler, epoch=epoch, model_ema=model_ema)
            test_stats = evaluate(data_loader_val, model, device, model_ema=model_ema, args=args)
            print(f"Accuracy of the network on the {len(dataset_val)} test images: {test_stats['acc1']:.1f}%")
            if max_accuracy < test_stats["acc1"]:
                max_accuracy = test_stats["acc1"]
                if args.output_dir:
                    # NOTE(review): helper name reconstructed from a truncated paste -- confirm.
                    utils.save_model(
                        args=args, model=model, model_without_ddp=model_without_ddp, optimizer=optimizer,
                        loss_scaler=loss_scaler, epoch="best", model_ema=model_ema)

            print(f'Max accuracy: {max_accuracy:.2f}%')
            if log_writer is not None:
                log_writer.update(test_acc1=test_stats['acc1'], head="test", step=epoch)
                log_writer.update(test_ema_acc1=test_stats['ema_acc1'], head="test", step=epoch)
            log_stats = {**{f'train_{k}': v for k, v in train_stats.items()},
                         **{f'test_{k}': v for k, v in test_stats.items()},
                         'epoch': epoch,
                         'n_parameters': n_parameters}
            with open(os.path.join(args.output_dir, "log.txt"), mode="a", encoding="utf-8") as f:
                f.write(json.dumps(log_stats) + "\n")

def train_one_epoch(model: torch.nn.Module,
                    args, train_config, data_loader: Iterable, optimizer: torch.optim.Optimizer,
                    amp_autocast, device: torch.device, epoch: int, loss_scaler, log_writer=None,
                    lr_scheduler=None, start_steps=None, lr_schedule_values=None, model_ema=None):
    """Train ``model`` for one pass over ``data_loader`` and return averaged stats.

    The model's parameters are updated in place.  Whatever weights the model
    holds when this returns are the weights the next epoch starts from; nothing
    here re-creates or re-initialises the model.

    Args:
        model: Network to train; put into training mode and updated in place.
        args, train_config: Run configuration (not read by this function).
        data_loader: Iterable yielding ``(images, targets)`` batches.
        optimizer: Optimizer stepping ``model``'s parameters.
        amp_autocast: Zero-argument callable returning a context manager that
            wraps the forward pass (e.g. an autocast factory or ``nullcontext``).
        device: Device each batch is moved to.
        epoch: Epoch index; used only for the progress header.
        loss_scaler: Called as ``loss_scaler(loss, optimizer, clip_grad=...,
            parameters=..., create_graph=...)`` and expected to run backward and
            the optimizer step (AMP-style scaler -- confirm), and to expose
            ``state_dict()["scale"]``.  If ``None``, a plain
            ``loss.backward()`` / ``optimizer.step()`` is used instead.
        log_writer: Optional logger with ``update(..., head=...)``.
        lr_scheduler: Optional scheduler with ``step_update(global_step)``.
        start_steps: Global iteration index of this epoch's first batch
            (defaults to 0).
        lr_schedule_values: Optional per-iteration LR table indexed by global
            step; each param group's LR becomes that value times its
            ``"lr_scale"`` entry.
        model_ema: Accepted for interface compatibility; not updated here.

    Returns:
        Dict mapping each metric-logger meter name to its global average.
    """
    import sys  # only needed to abort on a non-finite loss

    if start_steps is None:
        start_steps = 0

    metric_logger = utils.MetricLogger(delimiter="  ")
    metric_logger.add_meter('lr', utils.SmoothedValue(window_size=1, fmt='{value:.6f}'))
    metric_logger.add_meter('min_lr', utils.SmoothedValue(window_size=1, fmt='{value:.6f}'))
    header = 'Epoch: [{}]'.format(epoch)
    print_freq = 10

    # Evaluation between epochs may have switched the model to eval mode
    # (dropout off, batch-norm stats frozen); restore training mode.
    model.train()

    for step, (images, targets) in enumerate(metric_logger.log_every(data_loader, print_freq, header)):
        it = start_steps + step  # global training iteration
        if lr_schedule_values is not None:
            for param_group in optimizer.param_groups:
                param_group["lr"] = lr_schedule_values[it] * param_group["lr_scale"]

        images = images.to(device, non_blocking=True)
        targets = targets.to(device, non_blocking=True)

        with amp_autocast():
            logits = model(images)
            # self-training loss
            loss = F.cross_entropy(logits, targets)

        loss_value = loss.item()
        if not math.isfinite(loss_value):
            print("Loss is {}, stopping training".format(loss_value))
            sys.exit(1)

        # Clear gradients from the previous step; backward() accumulates into
        # .grad, so skipping this would sum gradients across iterations.
        optimizer.zero_grad()
        if loss_scaler is not None:
            loss_scaler(loss, optimizer, clip_grad=None, parameters=model.parameters(), create_graph=False)
            metric_logger.update(loss_scale=loss_scaler.state_dict()["scale"])
        else:
            # Without a scaler nothing else performs backward/step, so do it here.
            loss.backward()
            optimizer.step()

        lrs = [group["lr"] for group in optimizer.param_groups]
        min_lr, max_lr = min(lrs), max(lrs)

        metric_logger.update(loss=loss_value, lr=max_lr, min_lr=min_lr)
        if log_writer is not None:
            log_writer.update(lr=max_lr, head="opt")
            log_writer.update(min_lr=min_lr, head="opt")
        if lr_scheduler is not None:
            lr_scheduler.step_update(it)

    # NOTE(review): nothing here gathers stats across processes, so in
    # distributed runs these averages are local to this process.
    print("Averaged stats:", metric_logger)
    return {k: meter.global_avg for k, meter in metric_logger.meters.items()}

Hi Trang!

I haven’t looked at your code.

In the most common use case the model weights will not be the same at
the beginning of each epoch.

When you call opt.step(), the optimizer modifies the model weights so
that the model weights are different at the beginning of every forward pass.
There is nothing different, in this regard, about the first forward pass of
a new epoch.

So unless you have explicit code that resets your model weights to some
specific value between epochs, your model weights at the beginning of a
new epoch will be whatever they were set to by the last opt.step() call
in the previous epoch.

If you are concerned about this, just look through your code. If no code
explicitly resets your model weights between epochs, the weights from the
previous epoch are being automatically “preserved” – there will be no need
for you to do anything special to get this behavior.


K. Frank

1 Like