Stochastic Gradient Descent with Warm Restarts

I am in the middle of porting my research code to PyTorch (the right DL framework), but I rely on Stochastic Gradient Descent with Warm Restarts, and PyTorch does not have a full implementation of it.

So I tried to port this TensorFlow implementation. Could anybody please check whether it respects the behavior of a learning rate scheduler in PyTorch? I "winged" it, in the sense that I do not know how the learning rate scheduler and the optimizer interact under the hood, especially with something like Adam, where, from my understanding, the learning rate scheduler should only provide an upper bound.

You can find the port below:

import math
import warnings

from torch.optim.lr_scheduler import _LRScheduler

class CosineDecayRestarts(_LRScheduler):
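    # Defaults are an assumption: they follow TensorFlow's CosineDecayRestarts
    # (t_mul=2.0, m_mul=1.0, alpha=0.0) and PyTorch's last_epoch=-1 convention.
    def __init__(self, optimizer, first_decay_steps, t_mul=2.0, m_mul=1.0, alpha=0.0, last_epoch=-1):
        self.first_decay_steps = first_decay_steps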
        self.t_mul = t_mul
        self.m_mul = m_mul
        self.alpha = alpha

        super(CosineDecayRestarts, self).__init__(optimizer, last_epoch)

    def get_lr(self):
        if not self._get_lr_called_within_step:
            warnings.warn("To get the last learning rate computed by the scheduler, "
                          "please use `get_last_lr()`.", DeprecationWarning)

        if self.last_epoch == 0:
            return self.base_lrs

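        # Scale the initial learning rates; scaling group['lr'] (the current,
        # already decayed value) would compound the decay on every step.
        return [self._calculate_decayed_lr(base_lr) for base_lr in self.base_lrs]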

    def _calculate_decayed_lr(self, group_lr):
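        # last_epoch starts at 0 like TensorFlow's global step, whereas
        # _step_count is already 1 right after construction.
        completed_fraction = self.last_epoch / self.first_decay_steps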

        if self.t_mul != 1.0:
            # Geometric cycle lengths: solve for the restart index in closed form.
            i_restart = math.floor(
                math.log(1 - completed_fraction * (1 - self.t_mul)) / math.log(self.t_mul)
            )
            sum_r = (1.0 - self.t_mul ** i_restart) / (1.0 - self.t_mul)
            completed_fraction = (completed_fraction - sum_r) / self.t_mul ** i_restart
        else:
            # Equal cycle lengths: the integer part is the restart index.
            i_restart = math.floor(completed_fraction)
            completed_fraction = completed_fraction - i_restart

        m_fac = self.m_mul ** i_restart
        cosine_decayed = 0.5 * m_fac * (1.0 + math.cos(math.pi * completed_fraction))
        decayed = (1 - self.alpha) * cosine_decayed + self.alpha

        return group_lr * decayed
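
To check the schedule itself, independently of any optimizer, the decay formula can be pulled out into a plain function. The following is only a minimal sketch under that assumption: `cosine_restart_factor` is a hypothetical helper name (it is not part of PyTorch or TensorFlow), and the defaults are assumed to mirror TensorFlow's `CosineDecayRestarts`. It returns the multiplier that would be applied to the base learning rate at a given step.

```python
import math


def cosine_restart_factor(step, first_decay_steps, t_mul=2.0, m_mul=1.0, alpha=0.0):
    """Hypothetical helper: multiplier for the base learning rate at `step`."""
    completed_fraction = step / first_decay_steps

    if t_mul != 1.0:
        # Cycle lengths grow geometrically by t_mul; find the cycle index
        # in closed form, then the position inside that cycle.
        i_restart = math.floor(
            math.log(1.0 - completed_fraction * (1.0 - t_mul)) / math.log(t_mul)
        )
        sum_r = (1.0 - t_mul ** i_restart) / (1.0 - t_mul)
        completed_fraction = (completed_fraction - sum_r) / t_mul ** i_restart
    else:
        # All cycles have the same length.
        i_restart = math.floor(completed_fraction)
        completed_fraction -= i_restart

    # Each restart scales the peak by m_mul; alpha sets the floor of the decay.
    m_fac = m_mul ** i_restart
    cosine_decayed = 0.5 * m_fac * (1.0 + math.cos(math.pi * completed_fraction))
    return (1.0 - alpha) * cosine_decayed + alpha


if __name__ == "__main__":
    # Print the multiplier over the first two cycles (10 steps, then 20 steps).
    for step in range(0, 30, 5):
        print(step, round(cosine_restart_factor(step, 10, t_mul=2.0, m_mul=0.5), 4))
```

With `first_decay_steps=10`, `t_mul=2.0` and `m_mul=0.5`, the factor starts at 1.0, decays towards 0 over the first 10 steps, and restarts at 0.5 at step 10, when the second (twice as long) cycle begins. Comparing these values against the learning rates reported by the scheduler is one way to spot an off-by-one step count or a compounding decay.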