# SGHMC optimizer and closure()

Hello,

I’m trying to implement a generic optimizer that performs SGHMC (see Algorithm 2 in Chen et al.). In this algorithm, we can perform m updates of the parameters and momentum while keeping the same mini-batch. I would like to keep my training and validation loops independent of the optimizer I’m using, and thus I’m trying to add this m loop within the SGHMC optimizer.

Within the `step` method of my optimizer, I added a loop `for t in range(self.trajectory_param)` where `trajectory_param` corresponds to the parameter m (see below).

At the end of each iteration of this loop, I need to update the gradients, so I decided to call the `closure` after each update (see the code below the comment `# recomputes the gradients ...`).
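For context, my training loop passes a closure in the usual PyTorch way. A minimal sketch of what I mean, where `model`, `criterion`, `inputs` and `targets` are placeholders for my actual objects:

```python
# Sketch of the training loop; model, criterion, inputs and targets are placeholders.
def closure():
    optimizer.zero_grad()                     # clear stale gradients
    loss = criterion(model(inputs), targets)  # forward pass on the current mini-batch
    loss.backward()                           # repopulate parameter.grad
    return loss

loss = optimizer.step(closure)                # the optimizer may re-evaluate closure()
```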

Is this the right way to go?

Edit: this implementation is probably wrong because I’m calling the closure inside the loop over the model parameters.

```python
# (assumes `import torch` and `import numpy as np` at module level)
@torch.no_grad()
def step(self, closure=None):

    loss = None
    if closure is not None:
        # step() runs under no_grad, so re-enable autograd inside the closure
        closure = torch.enable_grad()(closure)
        loss = closure()

    for group in self.param_groups:

        # get the hyperparameters of this group
        lr = group["lr"]
        alpha = group["alpha"]
        weight_decay = group["weight_decay"]
        cv_flag = group["cv_flag"]

        for parameter in group["params"]:

            # skip parameters that did not receive a gradient
            if parameter.grad is None:
                continue

            state = self.state[parameter]
            if len(state) == 0:
                state["iteration"] = 0
                state["momentum"] = torch.zeros_like(parameter)
            if cv_flag:
                state["iteration"] += 1

            # get v_k
            momentum = state["momentum"]

            grad = parameter.grad
            if weight_decay != 0.0:
                grad = grad.add(parameter, alpha=weight_decay)

            if state["iteration"] > self.num_warmup_steps:
                if self.resample_momentum:
                    # resample v_k at the start of the trajectory
                    momentum.normal_(mean=0.0, std=np.sqrt(lr))  # assumed scale

                sigma = np.sqrt(alpha)
                for t in range(self.trajectory_param):
                    # w_scaled ~ N(0, alpha) / sqrt(lr), so that lr * w_scaled ~ N(0, alpha*lr)
                    sample_t = parameter.new(parameter.size()).normal_(mean=0.0, std=sigma) / np.sqrt(lr)

                    # v_k+1 = (1-alpha)*v_k - lr*grad + sqrt(alpha*lr)*w
                    if cv_flag:
                        pass  # control-variate correction of the gradient (omitted here)
                    momentum.mul_(1.0 - alpha).add_(grad, alpha=-lr).add_(sample_t, alpha=lr)

                    # theta_k+1 = theta_k + v_k+1
                    parameter.add_(momentum)

                    if t > 0:
                        # recomputes the gradients -> need to test this
                        loss = closure()
                        if weight_decay != 0.0:
                            grad = parameter.grad.add(parameter, alpha=weight_decay)
                        else:
                            grad = parameter.grad

    return loss
```

Aren’t your loops in the wrong order? I think you want them like:

```python
for t in range(self.trajectory_param):
    loss = closure()
    for p in parameters:
        ...
```

I checked the Pyro implementation (used from `infer.mcmc.hmc`); it seems to agree, doing one closure evaluation per step.

Otherwise, the closure approach should work in principle, but you should zero the gradients between calls.
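Concretely, the skeleton would be something like this (a sketch, not tested; note that since `step` is decorated with `@torch.no_grad()`, the closure has to be re-wrapped with `torch.enable_grad()`, as PyTorch’s LBFGS does):

```python
@torch.no_grad()
def step(self, closure=None):
    # step() runs under no_grad, so re-enable autograd for the closure
    closure = torch.enable_grad()(closure)

    loss = None
    for t in range(self.trajectory_param):
        loss = closure()            # one gradient evaluation per inner step
        for group in self.param_groups:
            for p in group["params"]:
                ...                 # one SGHMC update of (theta, v) per parameter
    return loss
```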

Hey, thanks for your reply. Indeed, the m loop was in the wrong place. I’m going to try the implementation below. I only need to figure out how to do a single SGD momentum step when the number of iterations is still below the number of warmup steps.

Thanks!

```python
# (assumes `import torch` and `import numpy as np` at module level)
@torch.no_grad()
def step(self, closure=None):

    loss = None
    if closure is not None:
        # step() runs under no_grad, so re-enable autograd inside the closure
        closure = torch.enable_grad()(closure)

    for t in range(self.trajectory_param):

        # one closure evaluation per inner step: fresh gradients for the whole model
        if closure is not None:
            loss = closure()

        for group in self.param_groups:

            # get the hyperparameters of this group
            lr = group["lr"]
            alpha = group["alpha"]
            weight_decay = group["weight_decay"]
            cv_flag = group["cv_flag"]

            for parameter in group["params"]:

                # skip parameters that did not receive a gradient
                if parameter.grad is None:
                    continue

                state = self.state[parameter]
                if len(state) == 0:
                    state["iteration"] = 0
                    state["momentum"] = torch.zeros_like(parameter)
                if cv_flag:
                    state["iteration"] += 1

                # get v_k
                momentum = state["momentum"]

                grad = parameter.grad
                if weight_decay != 0.0:
                    grad = grad.add(parameter, alpha=weight_decay)

                if state["iteration"] > self.num_warmup_steps:

                    if self.resample_momentum and t == 0:
                        # resample v_k once at the start of the trajectory
                        momentum.normal_(mean=0.0, std=np.sqrt(lr))  # assumed scale

                    sigma = np.sqrt(alpha)
                    # w_scaled ~ N(0, alpha) / sqrt(lr), so that lr * w_scaled ~ N(0, alpha*lr)
                    sample_t = parameter.new(parameter.size()).normal_(mean=0.0, std=sigma) / np.sqrt(lr)

                    # v_k+1 = (1-alpha)*v_k - lr*grad + sqrt(alpha*lr)*w
                    if cv_flag:
                        pass  # control-variate correction of the gradient (omitted here)
                    momentum.mul_(1.0 - alpha).add_(grad, alpha=-lr).add_(sample_t, alpha=lr)

                    # theta_k+1 = theta_k + v_k+1
                    parameter.add_(momentum)
                else:
                    pass  # single SGD momentum step during warmup, see the note below

    return loss
```
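For the warmup branch (`state["iteration"] <= self.num_warmup_steps`), my current idea is to simply drop the noise term, which reduces the update to SGD with momentum. A sketch of what I have in mind, untested:

```python
# Warmup: same update without the injected noise, i.e. plain SGD with momentum
# (reuses grad, momentum, lr and alpha from the surrounding loop).
momentum.mul_(1.0 - alpha).add_(grad, alpha=-lr)
parameter.add_(momentum)
```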