Typing issue with LambdaLR scheduler

I’m trying to create my own scheduler class that inherits from torch.optim.lr_scheduler.LambdaLR. My code looks like this:

```python
import torch
from torch.optim.lr_scheduler import LambdaLR


class DetectionWarmupScheduler(LambdaLR):
    r"""Warm-up learning rate scheduler for detection.

    For iteration x, the learning rate multiplier is given by:

        f(x) = warmup_factor * (1 - alpha) + alpha,  if x < warmup_iters
        f(x) = 1,                                    otherwise

    with alpha = x / warmup_iters.

    For this particular scheduler, warmup_factor = 1 / warmup_iters, and
    warmup_iters can't be greater than the dataloader length.

    Args:
        optimizer (torch optimizer): optimizer object.
        warmup_iters (int): number of warmup iterations.
        ndataloader (int): length of the dataloader.
    """

    def __init__(
        self,
        optimizer: torch.optim.Optimizer,
        warmup_iters: int,
        ndataloader: int,
    ) -> None:
        warmup_factor: float = 1. / warmup_iters
        # Cap the warm-up so it finishes within the first pass over the dataloader.
        min_warmup_iters: int = min(warmup_iters, ndataloader - 1)

        def f(x: int) -> float:
            # After warm-up, keep the base learning rate unchanged.
            if x >= min_warmup_iters:
                return 1.
            # Linear ramp from warmup_factor (x = 0) up to 1.
            alpha: float = float(x) / min_warmup_iters
            return warmup_factor * (1 - alpha) + alpha

        super().__init__(optimizer, lr_lambda=f)
```
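To see what `f` actually returns, here is a minimal pure-Python sketch of the same multiplier, pulled out of the class so it runs without PyTorch. `warmup_multiplier` is a hypothetical helper name; the arithmetic mirrors the `f` closure above. `LambdaLR` multiplies each parameter group's initial learning rate by this value at every `step()`.

```python
def warmup_multiplier(x: int, warmup_iters: int, ndataloader: int) -> float:
    """Learning-rate multiplier at iteration x (same logic as the closure f)."""
    warmup_factor = 1.0 / warmup_iters
    # Warm-up is capped so it finishes within the first pass over the dataloader.
    min_warmup_iters = min(warmup_iters, ndataloader - 1)
    if x >= min_warmup_iters:
        return 1.0
    alpha = x / min_warmup_iters
    # Linear ramp from warmup_factor (at x = 0) toward 1 (at x = min_warmup_iters).
    return warmup_factor * (1 - alpha) + alpha


if __name__ == "__main__":
    # warmup_iters=4 with a long dataloader: the multiplier ramps from 0.25 to 1.0.
    print([warmup_multiplier(x, warmup_iters=4, ndataloader=100) for x in range(6)])
    # [0.25, 0.4375, 0.625, 0.8125, 1.0, 1.0]
```

When the dataloader is shorter than `warmup_iters`, the ramp ends at `ndataloader - 1` but still starts from `1 / warmup_iters`, so the slope is steeper.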

The code works as expected, but when I run mypy on it, it fails with the following error:

```
44: error: Argument "lr_lambda" to "__init__" of "LambdaLR" has incompatible type "Callable[[int], float]"; expected "float"
```

According to the docs (and the source code), the `lr_lambda` argument of `LambdaLR` can be a function or a list of functions, but it is not explicitly annotated in the code. I don't understand why mypy says the expected type is `float`.
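For comparison, an annotation matching that documented behavior would be a union of a single callable and a list of callables. The sketch below uses a hypothetical stand-in class, `ToyLambdaScheduler` (not PyTorch's `LambdaLR`), to show such a signature. With it, passing a `Callable[[int], float]` type-checks; an annotation of `float` would instead make mypy report an error like the one above.

```python
from typing import Callable, List, Sequence, Union

# A single multiplier function, or one function per parameter group.
LrLambda = Callable[[int], float]
LrLambdaArg = Union[LrLambda, List[LrLambda]]


class ToyLambdaScheduler:
    """Hypothetical stand-in for LambdaLR, only to illustrate the annotation."""

    def __init__(self, base_lrs: Sequence[float], lr_lambda: LrLambdaArg) -> None:
        self.base_lrs = list(base_lrs)
        if isinstance(lr_lambda, list):
            if len(lr_lambda) != len(self.base_lrs):
                raise ValueError("expected one lr_lambda per parameter group")
            self.lr_lambdas = list(lr_lambda)
        else:
            # Reuse the same function for every parameter group.
            self.lr_lambdas = [lr_lambda] * len(self.base_lrs)

    def lrs_at(self, step: int) -> List[float]:
        # Each group's learning rate is its base rate times its multiplier.
        return [lr * fn(step) for lr, fn in zip(self.base_lrs, self.lr_lambdas)]


def half(x: int) -> float:
    return 0.5


sched = ToyLambdaScheduler([0.1, 0.2], lr_lambda=half)  # accepted by mypy
print(sched.lrs_at(0))  # [0.05, 0.1]
```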

However, if I copy the definition of `LambdaLR` into the same file and inherit from that copy instead, the error disappears.
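One plausible explanation, which is an assumption and not confirmed in this thread: mypy checks `torch.optim.lr_scheduler` against a type stub (`.pyi`) shipped with PyTorch rather than against the `.py` source. If that stub annotates `lr_lambda` as `float`, every callable is rejected, while a local copy of the class leaves `lr_lambda` unannotated, so mypy treats it as `Any` and accepts `f`. The sketch below reproduces that situation with a hypothetical base class, `StubbedBase`, whose parameter is deliberately mis-annotated, and shows two common local workarounds: a targeted `# type: ignore[arg-type]` comment and `typing.cast(Any, ...)`.

```python
from typing import Any, cast


class StubbedBase:
    """Hypothetical base class whose lr_lambda hint is wrong, like the suspected stub."""

    def __init__(self, lr_lambda: float) -> None:  # deliberately wrong annotation
        self.lr_lambda: Any = lr_lambda


def f(x: int) -> float:
    # Same ramp as the question's f with warmup_iters = 4 (warmup_factor = 0.25).
    return 1.0 if x >= 4 else 0.25 + 0.75 * x / 4


class WithIgnore(StubbedBase):
    def __init__(self) -> None:
        # Silence only this specific mypy error code on this line.
        super().__init__(lr_lambda=f)  # type: ignore[arg-type]


class WithCast(StubbedBase):
    def __init__(self) -> None:
        # cast() does nothing at runtime; it only changes the type mypy sees.
        super().__init__(lr_lambda=cast(Any, f))
```

In the scheduler from the question, the equivalent change would be appending `# type: ignore[arg-type]` to the `super().__init__(optimizer, lr_lambda=f)` line. If a later PyTorch release annotates `lr_lambda` correctly, the workaround can simply be dropped.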

Any clue on what could be the problem?


If you are interested in LR warm-up schedulers, try pytorch_warmup; there is a Colab example, and the package is available on PyPI.