SGHMC optimizer and closure()

Hello,

I’m trying to implement a generic optimizer that performs SGHMC (see Algorithm 2 in Chen et al.). In this algorithm, we can perform m updates of the parameters and momentum while keeping the same mini-batch. I would like to keep my training and validation loops independent of the optimizer I’m using, so I’m trying to put this m loop inside the SGHMC optimizer.

Within the step method of my optimizer, I added a loop, for t in range(self.trajectory_param), where trajectory_param corresponds to the parameter m (see below).
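To make the structure of Algorithm 2 concrete, here is a minimal toy sketch in plain Python, with no PyTorch. The names sghmc_toy and batch_noise and the mini-batch noise model are illustrative assumptions, not taken from the paper or any library. It samples the 1-D target U(θ) = θ²/2, i.e. N(0, 1), using the momentum form Δθ = v, Δv = -η∇Ũ(θ) - αv + N(0, 2αη) with η = lr. Each outer iteration stands for one mini-batch, modeled as a random gradient offset that is reused for all m inner updates. The momentum is optionally resampled once per outer iteration, before the m updates.

```python
import math
import random


def sghmc_toy(num_iters=10000, m=5, lr=0.01, alpha=0.1, batch_noise=0.5,
              resample_momentum=True, seed=0):
    """Toy SGHMC on U(theta) = theta**2 / 2 (target N(0, 1)).

    Hypothetical noise model: one "mini-batch" per outer iteration is a fixed
    gradient offset, reused for all m inner updates.
    """
    rng = random.Random(seed)
    theta, v = 0.0, 0.0
    noise_std = math.sqrt(2.0 * alpha * lr)  # N(0, 2*alpha*eta) with eta = lr
    samples = []
    for _ in range(num_iters):
        # one mini-batch per outer iteration, kept for all m inner updates
        offset = rng.gauss(0.0, batch_noise)
        if resample_momentum:
            # resample once per trajectory: v = sqrt(lr) * r with r ~ N(0, 1)
            v = math.sqrt(lr) * rng.gauss(0.0, 1.0)
        for _ in range(m):
            grad = theta + offset  # noisy gradient of U at the current theta
            # v <- (1 - alpha) * v - lr * grad + N(0, 2 * alpha * lr)
            v = (1.0 - alpha) * v - lr * grad + rng.gauss(0.0, noise_std)
            theta = theta + v  # theta <- theta + v
            samples.append(theta)
    return samples


samples = sghmc_toy()
mean = sum(samples) / len(samples)
var = sum((s - mean) ** 2 for s in samples) / len(samples)
print(f"mean={mean:.3f}  var={var:.3f}")
```

With these settings the sample mean and variance should come out close to 0 and 1. The momentum resampling sits outside the m loop. As I read Algorithm 2, that is where the optional resampling step belongs, and resampling inside the loop would discard the momentum built up along the trajectory.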

At the end of each iteration of this loop, I need to update the gradients, so I decided to call the closure after each update (see the code below the comment # recomputes the gradients).

Is this the right way to go?

Thanks for your help.

Edit: this implementation is probably wrong because I’m calling the closure inside the loop over the model parameters.

    @torch.no_grad()
    def step(self, closure=None):
        
        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:

            # get parameters
            lr = group["lr"]
            alpha = group["alpha"]
            weight_decay = group["weight_decay"]
            cv_flag = group['cv_flag']
            
            for parameter in group["params"]:

                if parameter.grad is None:
                    continue

                state = self.state[parameter]
                if len(state) == 0:
                    state["iteration"] = 0
                    state["momentum"] = torch.zeros_like(parameter)
                    if cv_flag:
                        parameter.total_grad = torch.zeros_like(parameter)
                        parameter.map_grad = torch.zeros_like(parameter)
                state["iteration"] += 1

                # get v_k 
                momentum = state["momentum"]

                # grad
                gradient = parameter.grad
                if weight_decay!=0.0:
                    gradient.add_(parameter, alpha=weight_decay)

                if state["iteration"] > self.num_warmup_steps:
                    if self.resample_momentum:
                        # resample v ~ N(0, lr/2) and use it in this step
                        state["momentum"] = torch.randn_like(parameter) * np.sqrt(0.5 * lr)
                        momentum = state["momentum"]

                    sigma = np.sqrt(alpha)
                    for t in range(self.trajectory_param):
                        # sample_t ~ N(0, alpha) / sqrt(lr), so that lr * sample_t ~ N(0, alpha*lr)
                        sample_t = parameter.new(parameter.size()).normal_(mean=0.0, std=sigma) / np.sqrt(lr)

                        # v_k+1 = (1-alpha)*v_k - 0.5*lr*grad + sqrt(alpha*lr)*w,  w ~ N(0, 1)
                        if cv_flag:
                            gradient.add_(parameter.map_grad, alpha=-1.0)
                            gradient.add_(parameter.total_grad, alpha=1.0)
                        momentum_t = momentum.mul_(1.0-alpha).add_(0.5*gradient - sample_t, alpha=-lr)

                        #theta_k+1 = theta_k + v_k
                        parameter.add_(momentum_t)
                        
                        if t>0:
                            # recomputes the gradients -> need to test this
                            with torch.enable_grad():
                                loss = closure()
                            gradient = parameter.grad
                            if weight_decay!=0.0:
                                gradient.add_(parameter, alpha=weight_decay)
                else:
                    momentum_t = momentum.mul_(1.0-alpha).add_(0.5*gradient, alpha=-lr)
                    parameter.add_(momentum_t)

        return loss
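For reference, the momentum update in the code above can be matched term by term to the momentum form of SGHMC in Chen et al. This is a short derivation that assumes unit mass and \hat\beta = 0 (no gradient-noise correction). The 0.5 factor and the scale of sample_t are read directly off the code, and g stands for the (possibly weight-decayed) gradient. It also gives the scale a resampled momentum should have.

```latex
\begin{aligned}
&\text{Chen et al., momentum form (unit mass, } \hat\beta = 0\text{):}\\
&\qquad \Delta\theta = v,\qquad \Delta v = -\eta\,\nabla\tilde U(\theta) - \alpha v + \mathcal{N}(0,\,2\alpha\eta)\\
&\text{Code, with } \mathrm{lr}\cdot\text{sample\_t} \sim \mathcal{N}(0,\,\alpha\,\mathrm{lr})\text{:}\\
&\qquad v \leftarrow (1-\alpha)\,v - \mathrm{lr}\left(\tfrac{1}{2}\,g - \text{sample\_t}\right)
  = (1-\alpha)\,v - \tfrac{\mathrm{lr}}{2}\,g + \mathcal{N}(0,\,\alpha\,\mathrm{lr}),\qquad \theta \leftarrow \theta + v\\
&\text{Matching terms: } \eta = \mathrm{lr}/2 \;\Rightarrow\; 2\alpha\eta = \alpha\,\mathrm{lr}\\
&\text{Resampled momentum: } v = \epsilon M^{-1} r,\; r \sim \mathcal{N}(0, M),\; \eta = \epsilon^{2} M^{-1}
  \;\Rightarrow\; v \sim \mathcal{N}(0,\,\eta) = \mathcal{N}(0,\,\mathrm{lr}/2)
\end{aligned}
```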

Aren’t your loops in the wrong order? I think you want them like:

    for t in range(self.trajectory_param):
        loss = closure()
        for p in parameters:
            # update p using the gradients from this closure call

I checked the Pyro implementation (infer.mcmc.hmc), and it seems to agree: it does one closure evaluation per step.
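The loop order suggested above can be sketched with a hypothetical toy optimizer: one closure evaluation per inner step, then an update of every parameter. ToyParam and ToyTrajectoryOptimizer are illustrative stand-ins, not PyTorch classes. The update is a plain gradient step, with no friction or noise, to keep the ordering visible.

```python
class ToyParam:
    """Hypothetical stand-in for a parameter tensor: a value and a .grad slot."""

    def __init__(self, value):
        self.value = value
        self.grad = None


class ToyTrajectoryOptimizer:
    """Hypothetical optimizer that only illustrates the loop order
    (plain gradient steps; no friction, noise, or momentum)."""

    def __init__(self, params, lr=0.1, trajectory_param=3):
        self.params = params
        self.lr = lr
        self.trajectory_param = trajectory_param

    def step(self, closure):
        loss = None
        for _ in range(self.trajectory_param):
            # one closure call per inner step: fresh gradients for ALL parameters
            loss = closure()
            # then update every parameter with those gradients
            for p in self.params:
                if p.grad is None:
                    continue
                p.value -= self.lr * p.grad
        return loss


params = [ToyParam(1.0), ToyParam(-2.0)]
calls = []


def closure():
    # loss = sum(value**2) / 2 on a fixed "mini-batch", so d loss / d value = value
    calls.append(1)
    for p in params:
        p.grad = p.value
    return sum(p.value ** 2 for p in params) / 2


opt = ToyTrajectoryOptimizer(params, lr=0.1, trajectory_param=3)
loss = opt.step(closure)
print(len(calls), [p.value for p in params], loss)
```

With three inner steps the closure runs exactly three times, and the returned loss is the one evaluated just before the last update.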

Otherwise, this closure approach should work in principle, but you should zero the gradients between calls, since backward() accumulates into .grad.
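Here is why zeroing matters. In PyTorch, loss.backward() adds into each parameter's .grad instead of overwriting it. Calling the closure m times without optimizer.zero_grad() (or self.zero_grad() inside step) therefore sums gradients taken at different points of the trajectory. The toy below mimics that accumulation with a hypothetical AccumParam class, which is not a PyTorch API.

```python
class AccumParam:
    """Hypothetical parameter whose backward() ADDS to .grad, mimicking how
    PyTorch accumulates gradients across backward() calls."""

    def __init__(self, value):
        self.value = value
        self.grad = None

    def backward(self):
        # gradient of value**2 / 2 is value; accumulate instead of overwrite
        g = self.value
        self.grad = g if self.grad is None else self.grad + g

    def zero_grad(self):
        self.grad = None


def make_closure(p, zero_first):
    def closure():
        if zero_first:
            p.zero_grad()
        p.backward()
        return p.value ** 2 / 2
    return closure


stale = AccumParam(3.0)
closure = make_closure(stale, zero_first=False)
closure()
closure()
print(stale.grad)  # 6.0: the first call's gradient leaked into the second

fresh = AccumParam(3.0)
closure = make_closure(fresh, zero_first=True)
closure()
closure()
print(fresh.grad)  # 3.0: only the current gradient
```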

Hey, thanks for your reply. Indeed, the m loop was in the wrong place. I’m going to try the implementation below. I still need to figure out how to do a single SGD momentum step while the number of iterations is below the number of warmup steps.

Thanks!

    @torch.no_grad()
    def step(self, closure=None):

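        loss = None
        for t in range(self.trajectory_param):

            if closure is not None:
                # fresh gradients at the current parameters for every inner step;
                # zero them first because backward() accumulates into .grad
                self.zero_grad()
                with torch.enable_grad():
                    loss = closure()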

            for group in self.param_groups:

                # get parameters
                lr = group["lr"]
                alpha = group["alpha"]
                weight_decay = group["weight_decay"]
                cv_flag = group['cv_flag']

                for parameter in group["params"]:

                    if parameter.grad is None:
                        continue

                    state = self.state[parameter]
                    if len(state) == 0:
                        state["iteration"] = 0
                        state["momentum"] = torch.zeros_like(parameter)
                        if cv_flag:
                            parameter.total_grad = torch.zeros_like(parameter)
                            parameter.map_grad = torch.zeros_like(parameter)
                    state["iteration"] += 1

                    # get v_k 
                    momentum = state["momentum"]

                    # grad
                    # out-of-place, so parameter.grad itself is left untouched
                    gradient = parameter.grad
                    if weight_decay != 0.0:
                        gradient = gradient.add(parameter, alpha=weight_decay)

                    if state["iteration"] > self.num_warmup_steps:

                        if self.resample_momentum and t == 0:
                            # resample once per trajectory, not at every inner step;
                            # v ~ N(0, lr/2) is the scale of v under the update below
                            state["momentum"] = torch.randn_like(parameter) * np.sqrt(0.5 * lr)
                            momentum = state["momentum"]

                        sigma = np.sqrt(alpha)
                        # sample_t ~ N(0, alpha) / sqrt(lr), so that lr * sample_t ~ N(0, alpha*lr)
                        sample_t = torch.randn_like(parameter) * sigma / np.sqrt(lr)
                        # v_k+1 = (1-alpha)*v_k - 0.5*lr*grad + sqrt(alpha*lr)*w,  w ~ N(0, 1)
                        if cv_flag:
                            # gradient correction: g - map_grad + total_grad (out-of-place)
                            gradient = gradient - parameter.map_grad + parameter.total_grad
                        momentum_t = momentum.mul_(1.0-alpha).add_(0.5*gradient - sample_t, alpha=-lr)
                        #theta_k+1 = theta_k + v_k
                        parameter.add_(momentum_t)
                    else:
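                        # single SGD-momentum step (no noise) during the burn-in stage
                        momentum_t = momentum.mul_(1.0-alpha).add_(0.5*gradient, alpha=-lr)
                        parameter.add_(momentum_t)
                        # TODO: this branch still runs trajectory_param times per call to step();
                        # only one update per call is wanted during burn-in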

        return loss
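One possible way to handle the burn-in case is sketched below. It is only one option. In the implementation above, state["iteration"] is incremented inside the t loop, so it counts inner updates rather than calls to step(), and the burn-in branch runs trajectory_param times. An alternative is to count calls to step() with a single counter (for example one stored on the optimizer rather than in each parameter's state) and decide before the trajectory loop how many inner updates to run. plan_step is a hypothetical helper, not part of torch.optim.

```python
def plan_step(step_count, num_warmup_steps, trajectory_param):
    """Hypothetical helper: decide how one call to step() behaves.

    Returns (number of inner updates, whether to inject SGHMC noise).
    During burn-in: a single noise-free SGD-momentum update.
    Afterwards: trajectory_param SGHMC updates on the same mini-batch.
    """
    if step_count <= num_warmup_steps:
        return 1, False
    return trajectory_param, True


# step_count counts calls to step(), starting at 1
plans = [plan_step(k, num_warmup_steps=2, trajectory_param=4) for k in range(1, 5)]
print(plans)  # [(1, False), (1, False), (4, True), (4, True)]
```

Inside step(), the result would drive the outer loop, e.g. n_inner, add_noise = plan_step(...) followed by for t in range(n_inner):, with the noise term and momentum resampling applied only when add_noise is true.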