`optimizer.step()` before `lr_scheduler.step()` error using GradScaler

I know this is an old post, but just wanted to get this out there in case it helps someone.

I would suggest checking

`skip_lr_sched = (scale > scaler.get_scale())`

instead of

`skip_lr_sched = (scale != scaler.get_scale())`

because, according to the docs, `scaler.update()` decreases the scale factor when `optimizer.step()` is skipped, and it also increases the scale factor when `optimizer.step()` has not been skipped for `growth_interval` consecutive iterations.

Simply checking `scale != scaler.get_scale()` will also be True when the scale factor is increased (i.e. when `optimizer.step()` has NOT been skipped), which would wrongly skip `lr_scheduler.step()` in that case.
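For context, here is a minimal, self-contained sketch of where this check sits relative to `scaler.step()` and `scaler.update()`. The model, data, and hyperparameters are made up for illustration (not from this thread), and it assumes a CUDA device since `GradScaler` is meant for CUDA AMP:

```python
import torch
import torch.nn as nn

device = "cuda"  # GradScaler is intended for CUDA AMP
model = nn.Linear(16, 4).to(device)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=10)
scaler = torch.cuda.amp.GradScaler()

for step in range(100):
    inputs = torch.randn(8, 16, device=device)
    targets = torch.randint(0, 4, (8,), device=device)

    optimizer.zero_grad()
    with torch.cuda.amp.autocast():
        loss = nn.functional.cross_entropy(model(inputs), targets)
    scaler.scale(loss).backward()

    scale = scaler.get_scale()  # scale before update()
    scaler.step(optimizer)      # skips optimizer.step() if grads contain inf/NaN
    scaler.update()             # decreases the scale if the step was skipped,
                                # and periodically increases it otherwise

    # The scale only decreases when optimizer.step() was skipped, so `>` is the
    # right comparison; `!=` would also trigger on the periodic growth.
    skip_lr_sched = scale > scaler.get_scale()
    if not skip_lr_sched:
        lr_scheduler.step()
```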
