The following code is my implementation of Santa with SSS algorithm written in here (supplementary material). The current version doesn’t make a trainig process converge independently of learning rate. On the other hand Adam and SGD are good at the NN I used.
Please let me know if there is any place to correct.
class SantaSSS(torch.optim.Optimizer):
"""
Following to Changyou Chen, David Carlson, Zhe Gan, Chunyuan Li, Lawrence Carin
"Bridging the Gap between Stochastic Gradient MCMC and Stochastic Optimization"
(arXiv:1512.07962v3)
"""
def __init__(self, params, lr, burnin, gamma=1.0, momentum=0.9, eps=1.0e-8):
"""
parameter
---------
lr : float
learning rate
burnin : int
How many exploration steps will be done.
gamma : float
Used to calc inverse temperture like the experiments of the paper.
beta = 1 / step ** gamma
momentum : float
A smoothing parameter named as sigma in the paper.
eps : float
A parameter named as lambda in the paper.
"""
defaults = dict()
defaults['lr'] = torch.tensor(lr, dtype=torch.float32)
defaults['burnin'] = torch.tensor(burnin, dtype=torch.int32)
defaults['gamma'] = torch.tensor(gamma, dtype=torch.float32)
defaults['momentum'] = torch.tensor(momentum, dtype=torch.float32)
defaults['eps'] = torch.tensor(eps, dtype=torch.float32)
defaults['step'] = torch.tensor(1, dtype=torch.int32)
defaults['noise'] = torch.distributions.Normal(loc=0.0, scale=1.0)
defaults['do_burn'] = torch.tensor(1, dtype=torch.uint8)
super().__init__(params, defaults)
for group in self.param_groups:
for p in group['params']:
state = self.state[p]
state['u'] = group['lr'].sqrt() * group['noise'].sample(p.shape)
state['v'] = torch.zeros(p.shape, dtype=torch.float32)
state['g'] = 1.0 / (group['eps'] + state['v']).sqrt()
state['alpha'] = group['lr'].sqrt() * torch.ones(p.shape, dtype=torch.float32)
def step(self, closure=None):
loss = None
if closure is not None:
loss = closure()
for group in self.param_groups:
lr = group['lr']
step = group['step']
noise = group['noise']
do_burn = group['do_burn']
T = 1.0 / step.type(torch.float32) ** group['gamma']
if do_burn:
if step > group['burnin']:
do_burn.fill_(0)
for p in group['params']:
if p.grad is None:
continue
grad = p.grad.data
state = self.state[p]
u, v, alpha = state['u'], state['v'], state['alpha']
v[:] = group['momentum'] * v + (1.0 - group['momentum']) * grad ** 2
g_new = 1.0 / (v.sqrt() + group['eps']).sqrt()
if do_burn:
g_old = state['g']
p.data.add_(g_new * u / 2.0)
alpha.add_((u * u - lr * T) / 2.0)
u.mul_((-alpha / 2.0).exp())
u.add_(-lr * g_new * grad
+ (2.0 * g_old * lr ** 1.5 * T).sqrt() * noise.sample(p.shape))
u.mul_((-alpha / 2.0).exp())
alpha.add_((u * u - lr * T) / 2.0)
p.data.add_(g_new * u / 2.0)
g_old[:] = g_new
else:
p.data.add_(g_new * u / 2.0)
u.mul_((-alpha / 2.0).exp())
u.add_(-lr * g_new * grad)
u.mul_((-alpha / 2.0).exp())
p.data.add_(g_new * u / 2.0)
step.add_(1)
return loss