I know this is an old post, but just wanted to get this out there in case it helps someone.
I would suggest checking

`skip_lr_sched = (scale > scaler.get_scale())`

instead of

`skip_lr_sched = (scale != scaler.get_scale())`

because, according to the docs, `scaler.update()` *decreases* the scale factor when `optimizer.step()` is skipped, but it also *increases* the scale factor when `optimizer.step()` has not been skipped for `growth_interval` consecutive iterations.
Simply checking `scale != scaler.get_scale()` will return `True` even when the scale factor is *increased* (i.e., when `optimizer.step()` has NOT been skipped), which would wrongly skip the scheduler step.
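For context, here is a minimal sketch of how the check fits into a standard AMP training loop. The `model`, `optimizer`, `scheduler`, and `dataloader` are just placeholders; the `GradScaler` usage itself follows the standard pattern from the PyTorch AMP docs:

```python
import torch

# Placeholders for illustration -- substitute your own model/data.
model = torch.nn.Linear(10, 1).cuda()
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=30)
scaler = torch.cuda.amp.GradScaler()

for inputs, targets in dataloader:  # `dataloader` assumed to exist
    optimizer.zero_grad()
    with torch.cuda.amp.autocast():
        loss = torch.nn.functional.mse_loss(model(inputs), targets)
    scaler.scale(loss).backward()

    # Record the scale *before* update() so a decrease can be detected.
    scale = scaler.get_scale()
    scaler.step(optimizer)
    scaler.update()

    # update() only decreases the scale when optimizer.step() was
    # skipped; an increase (after growth_interval successful steps)
    # means the step ran, so the scheduler should still advance.
    skip_lr_sched = (scale > scaler.get_scale())
    if not skip_lr_sched:
        scheduler.step()
```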