How to use PPOLoss with shared actor and critic parameters?

I am training an actor-critic model in a reinforcement learning environment. The model processes observations in a shared hidden module before passing the result to the policy and value heads, as shown in this diagram.

The TorchRL documentation for ClipPPOLoss indicates that you can train these kinds of networks in the following way:

If the actor and the value function share parameters, one can avoid calling the common module multiple times by passing only the head of the value network to the PPO loss module:

```python
common = SomeModule(in_keys=["observation"], out_keys=["hidden"])
actor_head = SomeActor(in_keys=["hidden"])
value_head = SomeValue(in_keys=["hidden"])
# first option, with 2 calls on the common module
model = ActorCriticOperator(common, actor_head, value_head)
loss_module = PPOLoss(model.get_policy_operator(), model.get_value_operator())
# second option, with a single call to the common module
loss_module = PPOLoss(ProbabilisticTensorDictSequential(model, actor_head), value_head)
```
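To make the difference between the two options concrete, here is a toy sketch that uses plain Python dictionaries in place of tensordicts. The functions `common`, `actor_head`, `value_head`, and `sequential` below are hypothetical stand-ins, not TorchRL modules; the sketch only counts how many times the shared trunk runs under each option.

```python
# Toy stand-ins for TensorDictModules: each reads and writes keys in a dict.
calls = {"common": 0}

def common(td):
    # Shared trunk: "observation" -> "hidden"
    calls["common"] += 1
    td["hidden"] = [2 * x for x in td["observation"]]
    return td

def actor_head(td):
    # Policy head: "hidden" -> "logits"
    td["logits"] = [h + 1 for h in td["hidden"]]
    return td

def value_head(td):
    # Value head: "hidden" -> "state_value"
    td["state_value"] = sum(td["hidden"])
    return td

def sequential(*modules):
    # Run modules in order on the same dict (loosely like a TensorDictSequential).
    def run(td):
        for m in modules:
            td = m(td)
        return td
    return run

# Option 1: policy and value operators each include the trunk -> 2 trunk calls.
policy_op = sequential(common, actor_head)
value_op = sequential(common, value_head)
td1 = {"observation": [1.0, 2.0]}
value_op(policy_op(td1))
trunk_calls_option1 = calls["common"]

# Option 2: the trunk runs once inside the policy; the loss only gets value_head,
# which reuses the "hidden" entry already written into the data.
calls["common"] = 0
policy_with_trunk = sequential(common, actor_head)
td2 = {"observation": [1.0, 2.0]}
value_head(policy_with_trunk(td2))
trunk_calls_option2 = calls["common"]

print(trunk_calls_option1, trunk_calls_option2)  # 2 1
```

The catch with option 2 is that it only works when "hidden" has already been written wherever `value_head` is later applied.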

I am trying to implement the second option, using TorchRL’s Generalized Advantage Estimation as the loss_module.value_estimator. Using the model as defined above works fine for filling up a SyncDataCollector with frames, but I’m running into issues with calculating the advantage on each training iteration, after the batch has been sampled.

```python
import torch

# GAE here refers to loss_module.value_estimator
with torch.no_grad():
    GAE(
        tensordict_data,
        params=loss_module.critic_network_params,
        target_params=loss_module.target_critic_network_params,
    )
```
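For reference, GAE needs a value estimate for both the current state and the next state at every step, which is why the value network is evaluated on the "next" entries as well as on the root. Below is a minimal pure-Python sketch of the standard GAE recursion for a single trajectory. It is not TorchRL's implementation; `rewards`, `values`, `next_values`, and `dones` are hypothetical 1-D lists.

```python
def gae(rewards, values, next_values, dones, gamma=0.99, lmbda=0.95):
    """Generalized Advantage Estimation over one trajectory.

    values[t] is V(s_t) (root entries), next_values[t] is V(s_{t+1})
    ("next" entries), and dones[t] is 1.0 if the episode ended at step t.
    """
    T = len(rewards)
    advantages = [0.0] * T
    running = 0.0
    for t in reversed(range(T)):
        not_done = 1.0 - dones[t]
        # The TD error needs the value of the *next* state.
        delta = rewards[t] + gamma * not_done * next_values[t] - values[t]
        # Discounted, lambda-weighted sum of future TD errors.
        running = delta + gamma * lmbda * not_done * running
        advantages[t] = running
    # Value targets are advantage + baseline.
    value_targets = [a + v for a, v in zip(advantages, values)]
    return advantages, value_targets

# With gamma = lmbda = 1 and zero values, advantages are the returns-to-go.
advantages, targets = gae([1.0, 1.0], [0.0, 0.0], [0.0, 0.0], [0.0, 1.0],
                          gamma=1.0, lmbda=1.0)
# advantages == [2.0, 1.0]
```

In the setup above, `values` and `next_values` would correspond to `value_head` applied to the root and to the "next" sub-tensordict respectively, so the second call has nothing to read if "next" lacks "hidden".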

Even though the collected tensordict_data contains the hidden key, value_head is being fed this empty input:

```
(NonTensorData(
    data=None,
    batch_size=torch.Size([]),
    device=None,
    is_shared=False),)
```

I’ve traced the error to GAE calling the _call_value_nets() function (see the linked documentation above). This function looks up the "hidden" key both in the current frame (where it is present) and under the nested ("next", "hidden") key. The second lookup produces the empty input because "hidden" is never defined in "next". I tried to get around this by setting the target parameters to None and passing shifted=True:
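The failure mode described above can be reproduced with nested dictionaries standing in for the collected tensordict. This is a hypothetical toy model, not TorchRL code: `call_value_nets` only mimics the idea of applying the value head to both the root and the "next" entries, and the last step assumes that running the shared trunk on the "next" observations is one way to make ("next", "hidden") available.

```python
def common(td):
    # Shared trunk: writes "hidden" from "observation".
    td["hidden"] = [2.0 * x for x in td["observation"]]
    return td

def value_head(td):
    # Value head reads only "hidden"; a missing key mirrors the empty input.
    if "hidden" not in td:
        raise KeyError("hidden")
    td["state_value"] = sum(td["hidden"])
    return td

def call_value_nets(data):
    # Mimics a value estimator: V(s_t) from the root, V(s_{t+1}) from data["next"].
    value = value_head(data)["state_value"]
    next_value = value_head(data["next"])["state_value"]
    return value, next_value

# What the collector produced: the policy ran the trunk on the root only.
data = {"observation": [1.0], "next": {"observation": [3.0], "reward": 1.0}}
common(data)

try:
    call_value_nets(data)
    failed = False
except KeyError:
    failed = True  # ("next", "hidden") was never written

# Hypothetical workaround: also run the shared trunk on the "next" entries.
common(data["next"])
value, next_value = call_value_nets(data)
print(failed, value, next_value)  # True 2.0 6.0
```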

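```python
from torchrl.objectives import ValueEstimators

loss_module.make_value_estimator(
    ValueEstimators.GAE, gamma=gamma, lmbda=lmbda, shifted=True
)
GAE = loss_module.value_estimator  # pick up the estimator just rebuilt

GAE(
    tensordict_data,
    params=loss_module.critic_network_params,
    target_params=None,
)
```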

But this raises the following error:

```
KeyError: "got keys {'hidden', 'state_value'} and set() which are incompatible
```
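This KeyError is consistent with the shifted path building one input by concatenating, along the time dimension, the root entries with the last "next" step, each restricted to the value network's in_keys plus the value key. The sketch below assumes that behavior and mimics it with plain dictionaries; `select`, `last_step`, and `cat_along_time` are hypothetical helpers, not TorchRL functions. Because "next" holds neither "hidden" nor "state_value", the two pieces expose incompatible key sets.

```python
def select(td, *keys):
    # Keep only the requested keys that are present (missing keys are skipped).
    return {k: td[k] for k in keys if k in td}

def last_step(td):
    # Slice every entry to its final time step, keeping a time dimension of 1.
    return {k: v[-1:] for k, v in td.items()}

def cat_along_time(pieces):
    # Concatenation is only defined if every piece has the same keys.
    first = set(pieces[0])
    for piece in pieces[1:]:
        if set(piece) != first:
            raise KeyError(f"got keys {first} and {set(piece)} which are incompatible")
    return {k: [x for piece in pieces for x in piece[k]] for k in first}

# Two time steps; "hidden" and "state_value" exist only at the root.
data = {
    "observation": [1.0, 2.0],
    "hidden": [2.0, 4.0],
    "state_value": [2.0, 4.0],
    "next": {"observation": [2.0, 3.0], "reward": [1.0, 1.0]},
}
keys = ("hidden", "state_value")  # value net in_keys + value key

try:
    cat_along_time([select(data, *keys), last_step(select(data["next"], *keys))])
    error = None
except KeyError as exc:
    error = exc
print(error)
```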

Is there a way to define "hidden" in "next", or am I missing something needed to implement PPO with shared actor/critic parameters?