Problem
With the code below, I get a learning rate of zero for all iterations when using a small number of training samples, e.g., batch_size=64, num_train_samples=74, num_epochs=10, warmup_epochs=2. The milestone that I set for the SequentialLR somehow seems to be wrong.
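For reference, plugging the failing values into the same ceil-based formula that get_scheduler() below uses gives a very small milestone; a quick sanity check of the arithmetic:

import math

batch_size = 64
num_train_samples = 74
warmup_epochs = 2

steps_per_epoch = math.ceil(num_train_samples / batch_size)  # ceil(74 / 64) = 2
warmup_steps = warmup_epochs * steps_per_epoch                # 2 * 2 = 4
print(steps_per_epoch, warmup_steps)                          # prints: 2 4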
What I tried
The learning rate adapts as intended for larger training runs; e.g., batch_size=64, num_train_samples=2328, num_epochs=2000, warmup_epochs=20 works fine.
Based on the print statements in get_scheduler(), the variable values are as expected, so the variables themselves do not seem to be the problem. Also, setting warmup_epochs=0 results in the intended constant learning rate.
I suspect there is something wrong with how I compute warmup_steps from batch_size, warmup_epochs, and num_train_samples, but I am at a loss as to what precisely is going wrong. Does anyone have an idea where my mistake is?
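To make the problem easier to reproduce outside of my codebase, here is a minimal stand-alone sketch of the same scheduler construction with the failing values plugged in. The nn.Linear model and the SGD optimizer are only placeholders for this example, not part of the actual codebase:

import math

import torch
from torch import nn
from torch.optim import lr_scheduler

# Placeholder model and optimizer, only so the schedulers have something to attach to.
model = nn.Linear(4, 4)
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)

# Same formula as in get_scheduler(): warmup_epochs * ceil(num_train_samples / batch_size).
warmup_steps = 2 * math.ceil(74 / 64)  # = 4

warmup_scheduler = lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda step: step / warmup_steps)
main_scheduler = lr_scheduler.ConstantLR(optimizer, factor=1., total_iters=0, last_epoch=-1)
scheduler = lr_scheduler.SequentialLR(
    optimizer, [warmup_scheduler, main_scheduler], milestones=[warmup_steps])

# One scheduler.step() per batch, as in the training loop below.
for step in range(10 * math.ceil(74 / 64)):  # num_epochs * steps per epoch
    optimizer.step()
    scheduler.step()
    print(step, optimizer.param_groups[0]['lr'])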
Code
The code is part of a complex codebase, so I have tried to provide only the relevant parts of the training process. I also included optimize_parameters, as I suspect there could be something wrong with the way I handle the scaler.
import argparse
import math
from typing import Tuple

import torch
from torch.optim import lr_scheduler
from tqdm import tqdm


def warmup_wrapper(warmup_steps: int):
    """We need a closure here to set `warmup_steps`."""
    def warmup(current_step: int):
        # Linear warmup.
        return current_step / warmup_steps
    return warmup


def get_scheduler(optimizer, opt: argparse.Namespace):
    """Return a learning rate scheduler.

    Parameters:
        optimizer -- the optimizer of the network
        opt (option class) -- stores all the experiment flags
    """
    # We do scheduler.step() after each batch, so we need to calculate the actual number of
    # warmup steps based on the warmup_epochs given by the user.
    if opt.warmup_epochs > 0:
        warmup_steps = opt.warmup_epochs * math.ceil(opt.num_train_samples / opt.batch_size)
        print(f"warmup_steps: {warmup_steps}")
        print(f"warmup_epochs: {opt.warmup_epochs}")
        print(f"train_samples: {opt.num_train_samples}")
    else:
        warmup_steps = 0

    main_scheduler = lr_scheduler.ConstantLR(optimizer, factor=1., total_iters=0, last_epoch=-1)
    if opt.warmup_epochs > 0:
        warmup_fn = warmup_wrapper(warmup_steps)
        warmup_scheduler = lr_scheduler.LambdaLR(optimizer, lr_lambda=warmup_fn)
        scheduler = lr_scheduler.SequentialLR(
            optimizer, [warmup_scheduler, main_scheduler], milestones=[warmup_steps])
    else:
        scheduler = main_scheduler
    return scheduler
...
class VisionMethod():
    ...

    def optimize_parameters(
        self, scaler: torch.cuda.amp.GradScaler, use_mixed_precision: bool
    ) -> Tuple[torch.cuda.amp.GradScaler, bool]:
        """Perform a forward pass, calculate the losses, and perform the backward pass
        for the current batch.
        """
        device = 'cuda' if 'cuda' in str(self.device) else 'cpu'
        with torch.autocast(device_type=device, dtype=torch.float16, enabled=use_mixed_precision):
            self.forward()
            self.calculate_losses()  # sets self.total_loss

        self.set_requires_grad([self.segmentation_network], True)
        if self.total_loss.requires_grad:
            scaler.scale(self.total_loss).backward()
            scaler.step(self.optimizer_segmentation)

        # We only want to update the learning rate schedulers if the scaler was updated.
        # See https://github.com/pytorch/pytorch/issues/55585
        scaler_before = scaler.get_scale()
        scaler.update()
        scaler_after = scaler.get_scale()
        update_schedulers = False
        if scaler_before <= scaler_after:
            update_schedulers = True

        self.optimizer_segmentation.zero_grad()
        return scaler, update_schedulers
def train_loop():
    ...
    method = VisionMethod()
    method.schedulers = get_scheduler(optimizer, opt)
    ...
    for epoch in range(next_epoch, opt.num_epochs + 1):
        method.train()
        iterations_in_epoch = 0
        scaler = torch.cuda.amp.GradScaler(enabled=opt.use_mixed_precision)
        with tqdm(
            total=len(train_dataset), desc='Training epoch {}/{}'.format(epoch, opt.num_epochs)
        ) as pbar:
            for i, data in enumerate(train_dataset):
                iterations_in_epoch += opt.batch_size
                total_iterations += opt.batch_size
                method.set_input(data)
                scaler, update_schedulers = method.optimize_parameters(scaler, opt.use_mixed_precision)
                # To prevent "UserWarning: Detected call of `lr_scheduler.step()` before `optimizer.step()`".
                if update_schedulers:
                    method.update_schedulers()
                del data
                if total_iterations % opt.print_freq < opt.batch_size:
                    learning_rate = method.optimizer.param_groups[0]['lr']
                    method.tensorboardnetworkwriter.add_scalar(
                        "Learning rate_{}".format(method.network_name), learning_rate, total_iterations)
                pbar.update(opt.batch_size)