Hessian Vector Product for discounted-return-based loss function

Hi all,

I am trying to implement a Hessian-guided SGD optimizer for an RL task, which requires Hessian-vector products (HVPs) computed in an online manner. I have tried computing the HVP with the following two approaches.

1st approach.

params_with_grad = []
for name, params in self.actor.named_parameters():
    if '_module' in name:
        params_with_grad.append(params)
actor_loss_PW = self.compute_actor_loss(pathwise_grad=True)
        
# first backward pass; keep the graph so we can differentiate again
PW_grads = torch.autograd.grad(actor_loss_PW, params_with_grad, create_graph=True)
# random probe vector, one tensor per parameter
v = [torch.rand_like(g) for g in PW_grads]
# second backward pass: d(sum_i g_i . v_i)/d(params) = Hv
HVP = torch.autograd.grad(PW_grads, params_with_grad, grad_outputs=v, retain_graph=True)

2nd approach.

params_with_grad = []
for name, params in self.actor.named_parameters():
    if '_module' in name:
        params_with_grad.append(params)
actor_loss_PW = self.compute_actor_loss(pathwise_grad=True)
        
# first backward pass; keep the graph so we can differentiate again
PW_grads = torch.autograd.grad(actor_loss_PW, params_with_grad, create_graph=True)
PW_grads = torch.cat([g.flatten() for g in PW_grads])  # flatten into a single vector
num_param = sum(p.numel() for p in params_with_grad)
v = torch.rand(num_param, device='cuda')  # random probe vector
HVP = torch.autograd.grad(PW_grads, params_with_grad, grad_outputs=v, retain_graph=True)

Both approaches ended up with

*** IndexError: pop from empty list.

The weird thing is that the error comes with no further explanation. I have wrapped some layers of my actor network with GradSampleModule from Opacus.

UPDATE. I found that removing the GradSampleModule wrapper solves the problem. However, I need GradSampleModule to get per-sample gradients.
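For anyone hitting the same error: the double-backward HVP pattern itself works fine on an unwrapped module. Here is a minimal, self-contained sketch (toy model and loss standing in for the actual actor):

```python
import torch
import torch.nn as nn

model = nn.Linear(4, 1)
x = torch.randn(16, 4)
loss = model(x).pow(2).mean()  # stand-in for the actor loss

params = [p for p in model.parameters() if p.requires_grad]
# first backward pass, keeping the graph so we can differentiate again
grads = torch.autograd.grad(loss, params, create_graph=True)

v = [torch.rand_like(g) for g in grads]  # probe vector
# second backward pass: d(sum_i g_i . v_i)/d(params) = H v
hvp = torch.autograd.grad(grads, params, grad_outputs=v)

# each HVP tensor has the same shape as the corresponding parameter
assert all(h.shape == p.shape for h, p in zip(hvp, params))
```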

The calculation of the loss is as follows:

    def compute_actor_loss(self, update_metrics = True, pathwise_grad = True):
        rew_acc = torch.zeros(self.num_envs, dtype=torch.float32, device=self.device)
        gamma = torch.ones(self.num_envs, dtype=torch.float32, device=self.device)
        if not pathwise_grad:
            logprob_acc = torch.zeros(self.num_envs, dtype=torch.float32, device=self.device)
        with torch.no_grad():
            if self.obs_rms is not None:
                obs_rms = deepcopy(self.obs_rms)
        
        # initialize trajectory to cut off gradients between episodes.
        obs = self.env.initialize_trajectory()
        obs = self._convert_obs(obs)
        if self.obs_rms is not None:
            # update obs rms
            with torch.no_grad():
                for k, v in obs.items():
                    self.obs_rms[k].update(v)
            # normalize the current obs
            obs = {k: obs_rms[k].normalize(v) for k, v in obs.items()}

        # collect trajectories and compute actor loss
        for i in range(self.horizon_len):
            # take env step
            if pathwise_grad:
                actions = self.get_actions(obs, eval = False)
            else:
                actions, logprob = self.get_actions(obs, eval = False, pathwise = False)
            obs, rew, done, extra_info = self.env.step(actions)
            obs = self._convert_obs(obs)

            with torch.no_grad():
                raw_rew = rew.clone()
            # scale the reward
            rew = self.reward_shaper(rew)

            if self.obs_rms is not None:
                # update obs rms
                with torch.no_grad():
                    for k, v in obs.items():
                        self.obs_rms[k].update(v)
                # normalize the current obs
                obs = {k: obs_rms[k].normalize(v) for k, v in obs.items()}

            done_env_ids = done.nonzero(as_tuple=False).squeeze(-1)
            not_done_env_ids = (done == 0).nonzero(as_tuple=False).squeeze(-1)

            # compute actor loss
            rew_acc[not_done_env_ids] = rew_acc[not_done_env_ids] + gamma[not_done_env_ids] * rew[not_done_env_ids]
            if not pathwise_grad:
                logprob_acc[not_done_env_ids] += logprob[not_done_env_ids].sum(dim = 1)

            # compute gamma for next step
            gamma = gamma * self.gamma

            # clear up gamma and rew_acc for done envs
            gamma[done_env_ids] = 1.0

        if pathwise_grad:
            actor_loss = -rew_acc
        else:
            _loss = -rew_acc
            _logprob = logprob_acc
            baseline = _loss.mean()
            adv = (_loss - baseline).detach()
            actor_loss = adv * _logprob
                    
        actor_loss /= self.horizon_len * self.num_envs
        actor_loss = actor_loss.sum()
        return actor_loss

You can get per-sample gradients using the torch.func API (see this), and you can use the torch.autograd.functional API to compute the HVP (or the VHP, which is apparently more efficient). So maybe you can combine the two instead of relying on Opacus?
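For example, a rough sketch on a toy linear model (not your actor; names and shapes are illustrative):

```python
import torch
import torch.nn as nn
from torch.func import functional_call, grad, vmap
from torch.autograd.functional import vhp

model = nn.Linear(3, 1)
params = dict(model.named_parameters())
x, y = torch.randn(8, 3), torch.randn(8, 1)

def loss_fn(params, x, y):
    pred = functional_call(model, params, (x,))
    return ((pred - y) ** 2).mean()

# per-sample gradients: vmap the per-example gradient over the batch dim
per_sample_grads = vmap(grad(loss_fn), in_dims=(None, 0, 0))(params, x, y)
# per_sample_grads['weight'] has shape (8, 1, 3): one gradient per sample

# VHP: express the loss as a function of the raw parameter tensors
def loss_wb(w, b):
    return ((x @ w.t() + b - y) ** 2).mean()

w, b = params['weight'].detach(), params['bias'].detach()
v = (torch.rand_like(w), torch.rand_like(b))
_, vhp_out = vhp(loss_wb, (w, b), v)  # one v^T H tensor per parameter
```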

Could you also clarify what you want to do, mathematically? In particular:

  1. What vector do you want to use for the hvp? (I’m guessing that the torch.rand in the code you shared is just for the sake of the example).
  2. What do you want to do with the per-sample gradients?

I am trying to compute a hybrid between the score-function and pathwise gradients. I sample once with the reward detached so that the gradient flows only through the log-probability, store that gradient, zero the gradients (this gives the SF gradient), then repeat for the PW gradient. I want to compute the sample variance of each gradient, which requires the per-sample gradients of each trajectory; that is what led me to Opacus.
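To illustrate what I mean by the two estimators, here is a toy one-dimensional Gaussian policy (a rough sketch; `mu` stands in for the policy parameters):

```python
import torch

torch.manual_seed(0)
mu = torch.tensor([0.5], requires_grad=True)
n = 1000

# pathwise: gradient flows through the reparameterized action a = mu + eps
actions_pw = mu + torch.randn(n)
reward_pw = -(actions_pw ** 2)           # differentiable reward
pw_grad = torch.autograd.grad(reward_pw.mean(), mu)[0]

# score function: action and reward detached, gradient only through logprob
actions_sf = (mu + torch.randn(n)).detach()
logprob = -0.5 * (actions_sf - mu) ** 2  # log N(a; mu, 1) up to a constant
reward_sf = -(actions_sf ** 2)           # already detached
sf_grad = torch.autograd.grad((reward_sf * logprob).mean(), mu)[0]

# both estimate d/dmu E[-a^2] = -2*mu = -1, with different variance
```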

I found that torch.func.grad/vmap only works for static datasets (e.g. in a supervised/unsupervised setting), whereas my interest is continuous control/RL. Specifically, torch.func.grad requires the network params and the dataset as explicit inputs, so I did not manage to use it for my setting. If you know how one can do this, I would love to hear it!

In terms of the underlying idea, I am trying to look at the directional curvature of the loss along the current gradient, so v would be the current gradient tensor rather than a random tensor.
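On a toy quadratic loss, the directional curvature I mean looks like this (here the Hessian is 2I, so the answer is exactly 2 along any direction):

```python
import torch

params = [torch.randn(5, requires_grad=True)]
loss = (params[0] ** 2).sum()  # Hessian is 2*I

grads = torch.autograd.grad(loss, params, create_graph=True)
v = [g.detach() for g in grads]  # direction = current gradient

hvp = torch.autograd.grad(grads, params, grad_outputs=v)
g = torch.cat([t.flatten() for t in v])
Hg = torch.cat([t.flatten() for t in hvp])
curvature = torch.dot(g, Hg) / g.norm() ** 2  # g^T H g / ||g||^2 = 2.0
```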

UPDATE. I found that I can still wrap my layers in GradSampleModule and manually enable/disable the Opacus hooks to overcome this issue:

for i in range(epochs):
    1. self.actor_mlp.layer_names.enable_hooks()
    2. compute and store SF grad samples
    3. compute and store PW grad samples
    4. self.actor_mlp.layer_names.disable_hooks()
    5. compute HVP and directional curvature
    6. update net params

Perhaps the reason is that Opacus's hooks do not function properly across multiple backward passes (two in my case). I believe this problem has been discussed somewhere on this forum.