# Santa algorithm implementation

The following code is my implementation of the Santa-SSS algorithm described here (supplementary material). With the current version, the training process fails to converge regardless of the learning rate. On the other hand, Adam and SGD both work well on the same network.

Please let me know if there is anything that needs correcting.

``````class SantaSSS(torch.optim.Optimizer):
"""
Following to Changyou Chen, David Carlson, Zhe Gan, Chunyuan Li, Lawrence Carin
"Bridging the Gap between Stochastic Gradient MCMC and Stochastic Optimization"
(arXiv:1512.07962v3)
"""

def __init__(self, params, lr, burnin, gamma=1.0, momentum=0.9, eps=1.0e-8):
"""
parameter
---------
lr : float
learning rate
burnin : int
How many exploration steps will be done.
gamma : float
Used to calc inverse temperture like the experiments of the paper.
beta = 1 / step ** gamma
momentum : float
A smoothing parameter named as sigma in the paper.
eps : float
A parameter named as lambda in the paper.
"""

defaults = dict()
defaults['lr'] = torch.tensor(lr, dtype=torch.float32)
defaults['burnin'] = torch.tensor(burnin, dtype=torch.int32)
defaults['gamma'] = torch.tensor(gamma, dtype=torch.float32)
defaults['momentum'] = torch.tensor(momentum, dtype=torch.float32)
defaults['eps'] = torch.tensor(eps, dtype=torch.float32)
defaults['step'] = torch.tensor(1, dtype=torch.int32)
defaults['noise'] = torch.distributions.Normal(loc=0.0, scale=1.0)
defaults['do_burn'] = torch.tensor(1, dtype=torch.uint8)

super().__init__(params, defaults)

for group in self.param_groups:
for p in group['params']:
state = self.state[p]
state['u'] = group['lr'].sqrt() * group['noise'].sample(p.shape)
state['v'] = torch.zeros(p.shape, dtype=torch.float32)
state['g'] = 1.0 / (group['eps'] + state['v']).sqrt()
state['alpha'] = group['lr'].sqrt() * torch.ones(p.shape, dtype=torch.float32)

def step(self, closure=None):

loss = None
if closure is not None:
loss = closure()

for group in self.param_groups:

lr = group['lr']
step = group['step']
noise = group['noise']
do_burn = group['do_burn']
T = 1.0 / step.type(torch.float32) ** group['gamma']

if do_burn:

if step > group['burnin']:

do_burn.fill_(0)

for p in group['params']:

continue

state = self.state[p]
u, v, alpha = state['u'], state['v'], state['alpha']

v[:] = group['momentum'] * v + (1.0 - group['momentum']) * grad ** 2
g_new = 1.0 / (v.sqrt() + group['eps']).sqrt()

if do_burn:

g_old =  state['g']

alpha.add_((u * u - lr * T) / 2.0)
u.mul_((-alpha / 2.0).exp())
+ (2.0 * g_old * lr ** 1.5 * T).sqrt() * noise.sample(p.shape))
u.mul_((-alpha / 2.0).exp())
alpha.add_((u * u - lr * T) / 2.0)

g_old[:] = g_new

else: