Hi all,
I am trying to implement a Hessian-guided SGD optimizer for an RL task, which requires the Hessian-vector product (HVP) to be computed in an online manner. I have tried computing the HVP using the following two approaches.
1st approach.
params_with_grad = []
for name, params in self.actor.named_parameters():
if '_module' in name:
params_with_grad.append(params)
actor_loss_PW = self.compute_actor_loss(pathwise_grad=True)
PW_grads = torch.autograd.grad(actor_loss_PW, params_with_grad,create_graph=True)
num_param = sum(p.numel() for p in params_with_grad)
v = [torch.rand(PW_grads[i].shape).to('cuda') for i in range(len(PW_grads))]
HVP = torch.autograd.grad(PW_grads, params_with_grad, v, retain_graph=True)
2nd approach.
params_with_grad = []
for name, params in self.actor.named_parameters():
if '_module' in name:
params_with_grad.append(params)
actor_loss_PW = self.compute_actor_loss(pathwise_grad=True)
PW_grads = torch.autograd.grad(actor_loss_PW, params_with_grad,create_graph=True)
PW_grads = torch.cat([e.flatten() for e in PW_grads]) # flatten
num_param = sum(p.numel() for p in params_with_grad)
v = torch.rand(num_param).to('cuda')
HVP = torch.autograd.grad(PW_grads, params_with_grad, v, retain_graph=True)
Both approaches ended up with
*** IndexError: pop from empty list.
The odd thing is that the traceback gives no further explanation for this error. I have wrapped some layers of my actor network with GradSampleModule from Opacus.
UPDATE. I found that removing the GradSampleModule wrapper solves the problem. However, I need GradSampleModule to get individual gradient samples.
The calculation of the loss is as follows:
def compute_actor_loss(self, update_metrics = True, pathwise_grad = True):
rew_acc = torch.zeros(self.num_envs, dtype=torch.float32, device=self.device)
gamma = torch.ones(self.num_envs, dtype=torch.float32, device=self.device)
if not pathwise_grad:
logprob_acc = torch.zeros(self.num_envs, dtype=torch.float32, device=self.device)
with torch.no_grad():
if self.obs_rms is not None:
obs_rms = deepcopy(self.obs_rms)
# actor_loss = torch.zeros(self.num_envs, dtype=torch.float32, device=self.device)
# initialize trajectory to cut off gradients between episodes.
obs = self.env.initialize_trajectory()
obs = self._convert_obs(obs)
# breakpoint()
if self.obs_rms is not None:
# update obs rms
with torch.no_grad():
for k, v in obs.items():
self.obs_rms[k].update(v)
# normalize the current obs
obs = {k: obs_rms[k].normalize(v) for k, v in obs.items()}
# collect trajectories and compute actor loss
for i in range(self.horizon_len):
# take env step
if pathwise_grad:
actions = self.get_actions(obs, eval = False)
else:
actions, logprob = self.get_actions(obs, eval = False, pathwise = False)
obs, rew, done, extra_info = self.env.step(actions)
obs = self._convert_obs(obs)
with torch.no_grad():
raw_rew = rew.clone()
# scale the reward
rew = self.reward_shaper(rew)
if self.obs_rms is not None:
# update obs rms
with torch.no_grad():
for k, v in obs.items():
self.obs_rms[k].update(v)
# normalize the current obs
obs = {k: obs_rms[k].normalize(v) for k, v in obs.items()}
done_env_ids = done.nonzero(as_tuple=False).squeeze(-1)
not_done_env_ids = (done == 0).nonzero(as_tuple=False).squeeze(-1)
# compute actor loss
rew_acc[not_done_env_ids] = rew_acc[not_done_env_ids] + gamma[not_done_env_ids] * rew[not_done_env_ids]
if not pathwise_grad:
logprob_acc[not_done_env_ids] += logprob[not_done_env_ids].sum(dim = 1)
# compute gamma for next step
gamma = gamma * self.gamma
# clear up gamma and rew_acc for done envs
gamma[done_env_ids] = 1.0
if pathwise_grad:
actor_loss = -rew_acc
else:
_loss = -rew_acc
_logprob = logprob_acc
baseline = _loss.mean()
adv = (_loss - baseline).detach()
actor_loss = adv * _logprob
actor_loss /= self.horizon_len * self.num_envs
actor_loss = actor_loss.sum()
return actor_loss