DDPG actor-critic with shared layer?

Hi, I’m implementing DDPG where the actor and critic share a module. I’m running into an issue and was wondering if I could get some feedback. I have the following:

import torch as T
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

INPUT_DIM = 100
BOTTLENECK_DIMS = 10

class SharedModule(nn.Module):
  def __init__(self):
    super().__init__()
    self.shared = nn.Linear(INPUT_DIM, BOTTLENECK_DIMS)

  def forward(self, x):
    return self.shared(x)

class ActorCritic(nn.Module):

  def __init__(self, n_actions, shared: SharedModule, lr=1e-3):
    super().__init__()
    self.shared = shared
    self.n_actions = n_actions
    self.lr = lr

    # Critic definition
    self.action_value = nn.Linear(self.n_actions, BOTTLENECK_DIMS)
    self.q = nn.Linear(BOTTLENECK_DIMS, 1)

    # Actor definition
    self.mu = nn.Linear(BOTTLENECK_DIMS, self.n_actions)

    self.optimizer = optim.Adam(self.parameters(), lr=self.lr)

  def forward(self, state, optional_action=None):
    if optional_action is None:
      return self._wo_action_fwd(state)
    return self._w_action_fwd(state, optional_action)

  def _wo_action_fwd(self, state):
    shared_output = self.shared(state)

    # Computing the actions
    mu_val = self.mu(F.relu(shared_output))
    actions = T.tanh(mu_val)

    # Computing the Q-vals
    action_value = F.relu(self.action_value(actions))
    state_action_value = self.q(
      F.relu(T.add(shared_output, action_value))
    )
    return actions, state_action_value

  def _w_action_fwd(self, state, action):
    shared_output = self.shared(state)
    action_value = F.relu(self.action_value(action))
    state_action_value = self.q(
      F.relu(T.add(shared_output, action_value))
    )
    return action, state_action_value

My training process is then:

shared_module = SharedModule()
actor_critic = ActorCritic(n_actions=3, shared=shared_module)

target_shared_module = SharedModule()
T_actor_critic = ActorCritic(n_actions=3, shared=target_shared_module)

s_batch, a_batch, r_batch, s_next_batch, d_batch = memory.sample(batch_size)

########################################
# Generate labels
########################################
# Get our critic target
_, y_critic = T_actor_critic(s_next_batch)
target = T.unsqueeze(
    r_batch + (gamma * d_batch * T.squeeze(y_critic)),
    dim=-1
)

########################################
# Critic Train
########################################
actor_critic.optimizer.zero_grad()
_, y_hat_critic = actor_critic(s_batch, a_batch)
critic_loss = F.mse_loss(target, y_hat_critic)
critic_loss.backward()
actor_critic.optimizer.step()

########################################
# Actor train
########################################
actor_critic.optimizer.zero_grad()
_, y_hat_policy = actor_critic(s_batch)
policy_loss = T.mean(-y_hat_policy)
policy_loss.backward()
actor_critic.optimizer.step()

Issues / doubts

  1. Looking at the OpenAI DDPG pseudocode, I’ve done steps 12 and 13 correctly (as far as I can tell). However, I don’t know how to do step 14.

The issue is that although I can calculate the entire Q-value, I don’t know how to take the derivative only with respect to theta. How should I go about doing this? I tried using

  def _wo_action_fwd(self, state):
    shared_output = self.shared(state)

    # Computing the actions
    mu_val = self.mu(F.relu(shared_output))
    actions = T.tanh(mu_val)

    # Computing the Q-vals
    with T.no_grad():
      action_value = F.relu(self.action_value(actions))
      state_action_value = self.q(
        F.relu(T.add(shared_output, action_value))
      )
    return actions, state_action_value

but, unsurprisingly, I got an error saying the output has no gradient, since I cut the graph off with no_grad().
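From what I’ve read, I think the usual trick is to keep the computation graph intact and instead control which parameters the optimizer actually updates, e.g. by giving the actor and critic separate optimizers over their own parameter subsets. Something like this sketch (the actor_params / critic_params split and the learning rates are my guesses, not code I’ve run):

# Sketch: separate optimizers, so the actor step only updates the actor (and shared?) weights
actor_params = list(actor_critic.mu.parameters()) + list(actor_critic.shared.parameters())
critic_params = list(actor_critic.action_value.parameters()) + list(actor_critic.q.parameters())
actor_optimizer = optim.Adam(actor_params, lr=1e-4)
critic_optimizer = optim.Adam(critic_params, lr=1e-3)

# Actor update: gradients still flow back through the critic head,
# but the step only changes the actor/shared parameters
actor_optimizer.zero_grad()
_, q_of_mu = actor_critic(s_batch)
policy_loss = -q_of_mu.mean()
policy_loss.backward()
actor_optimizer.step()

Is that the right reading of step 14, or should the critic weights also be frozen (requires_grad_(False)) during the actor pass?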

  2. This is more of a DDPG question than a PyTorch one, but is my translation of the algorithm correct? Should I do one optimizer step for the critic and then a separate one for the actor? I’ve seen

loss = torch.stack(policy_losses).sum() + torch.stack(value_losses).sum()
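Applied to my setup, I assume that combined version would look roughly like this (my guess, not something I’ve tested):

# Sketch: one combined backward/step instead of two separate updates
actor_critic.optimizer.zero_grad()
_, y_hat_critic = actor_critic(s_batch, a_batch)
_, y_hat_policy = actor_critic(s_batch)
loss = F.mse_loss(y_hat_critic, target) + (-y_hat_policy).mean()
loss.backward()
actor_critic.optimizer.step()

Is one of the two formulations preferred for DDPG?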

  3. Is there a way to train it so that the shared module is stable? I imagine that being trained on two separate losses (since I’m optimizing over two steps) might make convergence of the shared module wonky.
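One idea I had for #3 (untested, and the learning rates below are made up) is to give the shared trunk its own, smaller learning rate through Adam’s per-parameter-group options, so the two losses pull on it less hard:

# Sketch: smaller learning rate for the shared trunk than for the actor/critic heads
optimizer = optim.Adam([
  {'params': actor_critic.shared.parameters(), 'lr': 1e-4},
  {'params': actor_critic.mu.parameters()},
  {'params': actor_critic.action_value.parameters()},
  {'params': actor_critic.q.parameters()},
], lr=1e-3)

Would something like that help, or is there a more standard way to keep a shared trunk stable?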

Regarding #2: that line of code is from examples/actor_critic.py at master · pytorch/examples · GitHub, which I came across via the Actor Critic implementation problem thread.

Since I’m new I can’t post more than 2 links in my post womp womp