Learning rate always zero with SequentialLR


With the code below, I get a learning rate of zero for all iterations when using a small number of training samples, e.g., batch_size=64, num_train_samples=74, num_epochs=10, warmup_epochs=2. The milestone that I set seems to somehow be wrong.

What I tried

The learning rate adapts as intended for larger training runs; e.g., batch_size=64, num_train_samples=2328, num_epochs=2000, warmup_epochs=20 works fine.

Based on the print statements in get_scheduler() the variable values are as expected, so the variables seem to not be the problem here. Also, setting warmup_epochs=0 results in the intended constant learning rate.

I suspect there is something wrong with how I set warmup_steps based on batch_size, warmup_epochs and num_train_samples, but I am at a loss as to what precisely is going wrong. Does anyone have an idea where my mistake is?


The code is part of a complex codebase, so I tried to provide only the relevant parts of the training process. I also included optimize_parameters, as I suspect there could be something wrong with the way I treat the scaler.

def warmup_wrapper(warmup_steps: int):
    """Build the linear-warmup multiplier used as `lr_lambda` for `LambdaLR`.

    We need a closure here to set `warmup_steps`.

    Args:
        warmup_steps: Number of scheduler steps over which the factor ramps
            linearly from 0 to 1. A value <= 0 disables the warmup.

    Returns:
        A function mapping the current step index to a multiplicative
        learning rate factor that never exceeds 1.0.
    """
    def warmup(current_step: int) -> float:
        # No warmup requested: avoid dividing by zero and use the full rate.
        if warmup_steps <= 0:
            return 1.0
        # Linear warmup, capped so the factor cannot overshoot the base
        # learning rate if the scheduler is stepped past `warmup_steps`.
        return min(1.0, current_step / warmup_steps)
    return warmup

def get_scheduler(optimizer, opt: argparse.Namespace):
    """Return a learning rate scheduler.

    Parameters:
        optimizer          -- the optimizer of the network
        opt (option class) -- stores all the experiment flags; this function reads
                              opt.warmup_epochs, opt.num_train_samples and opt.batch_size

    Returns:
        A constant-factor scheduler when opt.warmup_epochs is not positive; otherwise a
        SequentialLR that runs a linear warmup first and then switches to the constant
        scheduler once `warmup_steps` scheduler steps have been taken.
    """
    use_warmup = opt.warmup_epochs > 0
    if use_warmup:
        # We do scheduler.step() after each batch, so we need to calculate the actual number of warmup steps
        # based on the warmup_epochs given by the user.
        warmup_steps = opt.warmup_epochs * math.ceil(opt.num_train_samples / opt.batch_size)
        print(f"warmup_steps: {warmup_steps}")
        print(f"warmup_epochs: {opt.warmup_epochs}")
        print(f"train_samples: {opt.num_train_samples}")

    # factor=1 with total_iters=0 keeps the optimizer's base learning rate unchanged.
    main_scheduler = lr_scheduler.ConstantLR(optimizer, factor=1., total_iters=0, last_epoch=-1)
    if not use_warmup:
        return main_scheduler

    warmup_fn = warmup_wrapper(warmup_steps)
    warmup_scheduler = lr_scheduler.LambdaLR(optimizer, lr_lambda=warmup_fn)
    # The milestone is the step index at which SequentialLR hands over to `main_scheduler`.
    return lr_scheduler.SequentialLR(
        optimizer, [warmup_scheduler, main_scheduler], milestones=[warmup_steps])


class VisionMethod():
    def optimize_parameters(
            self, scaler: torch.cuda.amp.GradScaler, use_mixed_precision: bool
     ) -> Tuple[torch.cuda.amp.GradScaler, bool]:
        """Perform a forward pass, calculate the losses, and perform the backward pass for the current batch.
        device = 'cuda' if 'cuda' in str(self.device) else 'cpu'
        with torch.autocast(device_type=device, dtype=torch.float16, enabled=use_mixed_precision):            
            self.calculate_losses() # sets self.total_loss
        self.set_requires_grad([self.segmentation_network], True)        
        if self.total_loss.requires_grad:
        # We only want to update the learning rate schedulers, if the scaler was updated.
        # See https://github.com/pytorch/pytorch/issues/55585
        scaler_before = scaler.get_scale()
        scaler_after = scaler.get_scale()
        update_schedulers = False
        if scaler_before <= scaler_after:
            update_schedulers = True
        return scaler, update_schedulers

def train_loop():
    """Run the training epochs, stepping the learning rate scheduler after every successful optimizer step.

    NOTE(review): this is an excerpt. `optimizer`, `opt`, `next_epoch`, `train_dataset` and `tqdm` are
    expected to come from the enclosing scope, and the step that feeds `data` to `method` is not part
    of the excerpt -- confirm against the full training script.
    """
    method = VisionMethod()
    method.schedulers = get_scheduler(optimizer, opt)

    # NOTE(review): counted in samples, starting at 0; when resuming from a checkpoint this
    # presumably has to be restored instead -- confirm.
    total_iterations = 0

    # Create the scaler once for the whole run. Re-creating it every epoch resets the loss scale to
    # its initial value; with only a few batches per epoch the scale may then shrink on every step,
    # so `update_schedulers` stays False and the learning rate never leaves its warmup start value.
    scaler = torch.cuda.amp.GradScaler(enabled=opt.use_mixed_precision)

    for epoch in range(next_epoch, opt.num_epochs + 1):
        iterations_in_epoch = 0

        with tqdm(
            total=len(train_dataset), desc='Training epoch {}/{}'.format(epoch, opt.num_epochs)
        ) as pbar:
            for i, data in enumerate(train_dataset):
                iterations_in_epoch += opt.batch_size
                total_iterations += opt.batch_size

                scaler, update_schedulers = method.optimize_parameters(scaler, opt.use_mixed_precision)
                # To prevent "UserWarning: Detected call of `lr_scheduler.step()` before `optimizer.step()`".
                if update_schedulers:
                    method.schedulers.step()
                del data

                if total_iterations % opt.print_freq < opt.batch_size:
                    learning_rate = method.optimizer.param_groups[0]['lr']
                    # Logged against the running sample counter, the only iteration counter kept here.
                    method.tensorboardnetworkwriter.add_scalar(
                        "Learning rate_{}".format(method.network_name), learning_rate, total_iterations)