Santa algorithm implementation

The following code is my implementation of the Santa-SSS algorithm described in the paper cited below (supplementary material). With the current version, training fails to converge no matter which learning rate I try, whereas Adam and SGD both train the same network without trouble.

Please let me know if you spot anything that needs correcting.
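
For reference, here is my reading of the quantities the optimizer tracks at step $t$ (I fold the paper's $1/N^2$ factor into the mean-loss gradient):

$$
v_t = \sigma\, v_{t-1} + (1 - \sigma)\, \tilde f_t \odot \tilde f_t,
\qquad
g_t = 1 \oslash \sqrt{\lambda + \sqrt{v_t}},
\qquad
\beta_t = t^{\gamma},
$$

with initialization $u_0 = \sqrt{\eta}\,\mathcal{N}(0, I)$, $v_0 = 0$, and $\alpha_0 = \sqrt{\eta}$, where $\eta$ is the learning rate and $\tilde f_t$ the minibatch gradient.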

import torch


class SantaSSS(torch.optim.Optimizer):
    """
    Santa with the symmetric splitting scheme (SSS), following
    Changyou Chen, David Carlson, Zhe Gan, Chunyuan Li, Lawrence Carin,
    "Bridging the Gap between Stochastic Gradient MCMC and Stochastic
    Optimization" (arXiv:1512.07962v3), Algorithm 2.
    """

    def __init__(self, params, lr, burnin, gamma=1.0, momentum=0.9, eps=1.0e-8):
        """
        Parameters
        ----------
        lr : float
            Learning rate (eta in the paper).
        burnin : int
            Number of exploration steps before switching to refinement.
        gamma : float
            Exponent of the inverse-temperature annealing used in the
            paper's experiments: beta_t = step ** gamma, so the code
            works with the temperature T = 1 / step ** gamma.
        momentum : float
            The smoothing parameter called sigma in the paper.
        eps : float
            The regularization parameter called lambda in the paper.
        """

        defaults = dict()
        defaults['lr'] = torch.tensor(lr, dtype=torch.float32)
        defaults['burnin'] = torch.tensor(burnin, dtype=torch.int32)
        defaults['gamma'] = torch.tensor(gamma, dtype=torch.float32)
        defaults['momentum'] = torch.tensor(momentum, dtype=torch.float32)
        defaults['eps'] = torch.tensor(eps, dtype=torch.float32)
        defaults['step'] = torch.tensor(1, dtype=torch.int32)
        defaults['noise'] = torch.distributions.Normal(loc=0.0, scale=1.0)
        defaults['do_burn'] = torch.tensor(1, dtype=torch.uint8)

        super().__init__(params, defaults)

        for group in self.param_groups:
            for p in group['params']:
                state = self.state[p]
                # Paper's initialization: u_0 = sqrt(eta) * N(0, I), v_0 = 0,
                # alpha_0 = sqrt(eta) * C (here C = 1). State is created on the
                # parameter's device so the optimizer also works on GPU.
                state['u'] = group['lr'].sqrt() * group['noise'].sample(p.shape).to(p.device)
                state['v'] = torch.zeros_like(p)
                state['g'] = 1.0 / (group['eps'] + state['v'].sqrt()).sqrt()
                state['alpha'] = group['lr'].sqrt() * torch.ones_like(p)

    def step(self, closure=None):

        loss = None
        if closure is not None:
            loss = closure()

        for group in self.param_groups:

            lr = group['lr']
            step = group['step']
            noise = group['noise']
            do_burn = group['do_burn']
            # Temperature T = 1 / beta_t, with inverse temperature beta_t = step ** gamma.
            T = 1.0 / step.type(torch.float32) ** group['gamma']

            # Switch from exploration to refinement once the burn-in period ends.
            if do_burn and step > group['burnin']:
                do_burn.fill_(0)
            
            for p in group['params']:

                if p.grad is None:
                    continue

                grad = p.grad.data

                state = self.state[p]
                u, v, alpha = state['u'], state['v'], state['alpha']

                # v_t = sigma * v_{t-1} + (1 - sigma) * grad^2
                # g_t = 1 / sqrt(lambda + sqrt(v_t))
                v.mul_(group['momentum']).add_((1.0 - group['momentum']) * grad ** 2)
                g_new = 1.0 / (group['eps'] + v.sqrt()).sqrt()

                if do_burn:

                    # Exploration update (symmetric splitting), with T = 1 / beta_t.
                    g_old = state['g']

                    p.data.add_(g_new * u / 2.0)
                    alpha.add_((u * u - lr * T) / 2.0)
                    u.mul_((-alpha / 2.0).exp())
                    # Momentum update: gradient force, thermal noise (sampled on the
                    # parameter's device), and the preconditioner-gradient correction
                    # (eta / beta_t) * (1 - g_old / g_new) / u, cf. the nabla-G term
                    # in the paper's exploration step.
                    u.add_(-lr * g_new * grad
                           + (2.0 * g_old * lr ** 1.5 * T).sqrt()
                           * noise.sample(p.shape).to(p.device)
                           + lr * T * (1.0 - g_old / g_new) / u)
                    u.mul_((-alpha / 2.0).exp())
                    alpha.add_((u * u - lr * T) / 2.0)
                    p.data.add_(g_new * u / 2.0)

                    g_old[:] = g_new

                else:

                    # Refinement update: alpha is fixed and no thermal noise is injected.
                    p.data.add_(g_new * u / 2.0)
                    u.mul_((-alpha / 2.0).exp())
                    u.add_(-lr * g_new * grad)
                    u.mul_((-alpha / 2.0).exp())
                    p.data.add_(g_new * u / 2.0)

            step.add_(1)

        return loss
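
For completeness, this is roughly how I drive the optimizer; the model, data, and hyperparameter values here are placeholders for illustration, not my actual network:

import torch

# Placeholder model and synthetic data, just to show the call pattern.
model = torch.nn.Sequential(
    torch.nn.Linear(10, 32),
    torch.nn.ReLU(),
    torch.nn.Linear(32, 1),
)
optimizer = SantaSSS(model.parameters(), lr=1.0e-4, burnin=500)
loss_fn = torch.nn.MSELoss()

x = torch.randn(64, 10)  # dummy inputs
y = torch.randn(64, 1)   # dummy targets

for _ in range(1000):
    optimizer.zero_grad()
    loss = loss_fn(model(x), y)
    loss.backward()
    optimizer.step()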