# WAME Algorithm Implementation Not Working

Hi,
I’ve tried implementing Weight-wise Adaptive learning rates with Moving average Estimator (WAME) based on this paper, but my loss doesn’t seem to be going down fast enough.
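For reference, here is the per-weight update rule I'm trying to implement, as I understand it from the paper (ζ is a per-weight acceleration factor clamped to [ζ_min, ζ_max], ∂_i E(t) is the gradient for weight i at step t, and λ is the global learning rate):

```latex
\zeta_i(t) =
  \begin{cases}
    \min\bigl(\zeta_i(t-1)\,\eta_{+},\ \zeta_{\max}\bigr) & \text{if } \partial_i E(t)\,\partial_i E(t-1) > 0 \\
    \max\bigl(\zeta_i(t-1)\,\eta_{-},\ \zeta_{\min}\bigr) & \text{if } \partial_i E(t)\,\partial_i E(t-1) < 0 \\
    \zeta_i(t-1) & \text{otherwise}
  \end{cases}
\qquad
\begin{aligned}
Z_i(t)        &= \alpha\,Z_i(t-1) + (1-\alpha)\,\zeta_i(t) \\
\theta_i(t)   &= \alpha\,\theta_i(t-1) + (1-\alpha)\,\partial_i E(t)^2 \\
\Delta w_i(t) &= -\lambda\,Z_i(t)\,\partial_i E(t)\,\theta_i(t)^{-1}
\end{aligned}
```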
Here’s my implementation:

```python
import torch
from torch.optim import Optimizer

class WAME(Optimizer):
"""Implements the Weight–wise Adaptive learning rates with Moving average
Estimator.

Arguments:
params (iterable): iterable of paramenters to optimize
alpha (float, optional): smoothing constant (default: 0.9)
etas (Tuple[float, float], optional): pair of (etaminus, etaplus), that
are multiplicative increase and decrease factors
(default: (0.1, 1.2))
step_sizes (Tuple[float, float], optional): a pair of minimal and
maximal allowed step sizes (default: (0.01, 100))
"""

    def __init__(self, params, lr=1e-3, alpha=0.9, etas=(0.1, 1.2),
                 step_sizes=(0.01, 100)):
        # lr corresponds to lambda in the paper
        if not 0.0 <= alpha < 1.0:
            raise ValueError(f"Invalid alpha value: {alpha}")
        if not 0.0 < etas[0] < 1.0 < etas[1]:
            raise ValueError(f"Invalid eta values: {etas[0]}, {etas[1]}")

        defaults = dict(lr=lr, alpha=alpha, etas=etas, step_sizes=step_sizes)
        super(WAME, self).__init__(params, defaults)

    def __setstate__(self, state):
        super(WAME, self).__setstate__(state)

    def step(self, closure=None):
        """Perform a single optimization step.

        Arguments:
            closure (callable, optional): A closure that reevaluates the model
                and returns the loss.
        """
        loss = None
        if closure is not None:
            loss = closure()

        for group in self.param_groups:
            for p in group["params"]:
                if p.grad is None:
                    continue
                grad = p.grad.data
                if grad.is_sparse:
                    raise RuntimeError("WAME does not support sparse gradients")
                state = self.state[p]

                # State initialization
                if len(state) == 0:
                    state["step"] = 0
                    state["prev"] = torch.zeros_like(p.data)
                    state["theta"] = torch.zeros_like(p.data)
                    state["z"] = torch.zeros_like(p.data)
                    # per-weight acceleration factor (zeta in the paper),
                    # which the paper initializes to 1
                    state["step_size"] = torch.ones_like(p.data)

                etaminus, etaplus = group["etas"]
                step_size_min, step_size_max = group["step_sizes"]
                alpha = group["alpha"]
                step_size = state["step_size"]
                theta = state["theta"]
                z = state["z"]

state["step"] += 1

if mul_dx > 0:
step_size = min(step_size*etaplus, step_size_max)
elif mul_dx < 0:
step_size = max(step_size*etaminus, step_size_min)

z = alpha * z + (1 - alpha) * step_size
theta = alpha * theta + (1 - alpha) * (grad ** 2)

                # update parameters
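                # (Sketch of the update step, per my reading of the paper:
                # delta_w = -lambda * Z * grad / theta. The eps term is my own
                # addition to avoid dividing by zero on the first few steps.)
                eps = 1e-12
                p.data.addcdiv_(z.mul(grad), theta.add(eps), value=-group["lr"])

                # remember the gradient for the next sign comparison
                state["prev"].copy_(grad)

        return loss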