I'm trying to create my own scheduler class that inherits from `torch.optim.lr_scheduler.LambdaLR`. My code looks like this:
```python
import torch
from torch.optim.lr_scheduler import LambdaLR


class DetectionWarmupScheduler(LambdaLR):
    r"""Warm-up learning rate scheduler for detection.

    For iteration ``x``, the learning rate multiplier is given by:

               / warmup_factor * (1 - alpha) + alpha   if x < warmup_iters
        f(x) = <
               \ 1                                     otherwise

    with ``alpha = x / warmup_iters``.

    For this particular scheduler, ``warmup_factor = 1 / warmup_iters``, and
    ``warmup_iters`` can't be greater than the dataloader length.

    Args:
        optimizer (torch.optim.Optimizer): optimizer object.
        warmup_iters (int): number of warmup iterations.
        ndataloader (int): length of the dataloader.
    """

    def __init__(
        self,
        optimizer: torch.optim.Optimizer,
        warmup_iters: int,
        ndataloader: int,
    ) -> None:
        warmup_factor: float = 1.0 / warmup_iters
        min_warmup_iters: int = min(warmup_iters, ndataloader - 1)

        def f(x: int) -> float:
            if x >= min_warmup_iters:
                return 1.0
            alpha: float = float(x) / min_warmup_iters
            return warmup_factor * (1 - alpha) + alpha

        super().__init__(optimizer, lr_lambda=f)
```
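To illustrate what the scheduler computes, here is a small self-contained check of the warm-up multiplier, with the inner function `f` reimplemented standalone (the name `warmup_multiplier` and the sample values are mine, just for illustration):

```python
def warmup_multiplier(x: int, warmup_iters: int, ndataloader: int) -> float:
    # Mirrors the inner function f(...) of DetectionWarmupScheduler.
    warmup_factor = 1.0 / warmup_iters
    min_warmup_iters = min(warmup_iters, ndataloader - 1)
    if x >= min_warmup_iters:
        return 1.0
    alpha = float(x) / min_warmup_iters
    return warmup_factor * (1 - alpha) + alpha


# The multiplier ramps linearly from warmup_factor up to 1,
# then stays at 1 for the rest of training.
for step in (0, 5, 10, 50):
    print(step, warmup_multiplier(step, warmup_iters=10, ndataloader=100))
```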
The code works as expected, but when I run mypy on it, it fails with the following error:

```
44: error: Argument "lr_lambda" to "__init__" of "LambdaLR" has incompatible type "Callable[[int], float]"; expected "float"
```
Reading the docs (and taking a look at the code), the `lr_lambda` argument to `LambdaLR` can be a function or a list of functions, but it is not explicitly typed in the code, so I don't know why mypy says the expected type is `float`.
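For reference, this is the annotation I would expect `lr_lambda` to have based on the docs; the alias name `LrLambda` is my own sketch, not anything from torch's actual source or stubs:

```python
from typing import Callable, List, Union

# A per-iteration multiplier: either one function applied to every
# parameter group, or one function per parameter group.
LrLambda = Union[Callable[[int], float], List[Callable[[int], float]]]


def warmup(x: int) -> float:
    # Toy multiplier: ramps from 0 to 1 over 10 iterations.
    return min(1.0, x / 10)


single: LrLambda = warmup            # one function for all param groups
per_group: LrLambda = [warmup, warmup]  # one function per param group
```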
However, if I copy the definition of the `LambdaLR` class into that same file and inherit from it there, the problem disappears.

Any clue as to what the problem could be?
Thanks