What is the exact format of the input TensorDict for ClipPPOLoss's forward method?

I checked the official document for ClipPPOLoss.

For `forward(tensordict: TensorDictBase) → TensorDictBase`, the documentation describes the parameter only as:

tensordict – an input tensordict with the values required to compute the loss.

This is too vague: I had a hard time figuring out what the input TensorDict is supposed to look like (which keys it must contain and what their values should be).
It seems that ClipPPOLoss uses GAE internally, so its input TensorDict presumably needs the same entries as the input to GAE.
Does anyone know the exact format of the input TensorDict for ClipPPOLoss?

For reference, here is the script I wrote while investigating:

import torch
from tensordict import TensorDict
from torch import nn
from tensordict.nn import TensorDictModule, InteractionType
from torch.distributions import Categorical
from torchrl.modules import ProbabilisticActor, ValueOperator
from torchrl.objectives import ClipPPOLoss
from torchrl.objectives.value import GAE


def main():
    """Build a minimal actor/critic pair and compute a ClipPPOLoss on a toy batch.

    The script doubles as a worked example of the input TensorDict consumed by
    ``ClipPPOLoss.forward``. Just before the loss is called, the root TensorDict
    contains:

    * ``"states"`` -- observations read by both the policy and the critic;
    * ``"logits"``, ``"action"``, ``"sample_log_prob"`` -- written by the actor;
    * the entries written by GAE (e.g. ``"advantage"``; print the TensorDict
      to see the full set for your torchrl version);
    * ``"next"`` -- a nested TensorDict holding the next ``"states"`` together
      with ``"reward"``, ``"done"`` and ``"terminated"``.

    The authoritative list of keys the loss reads is ``loss_module.in_keys``,
    which is printed before the loss is computed.
    """
    policy_network = nn.Linear(2, 4)
    policy_module = TensorDictModule(
        module=policy_network,
        in_keys=["states"],
        out_keys=["logits"],
    )
    actor = ProbabilisticActor(
        module=policy_module,
        in_keys=["logits"],
        # ClipPPOLoss expects the collected action under the "action" key.
        out_keys=["action"],
        distribution_class=Categorical,
        default_interaction_type=InteractionType.MODE,
        # ClipPPOLoss needs the log-probability of the collected actions
        # ("sample_log_prob"), so the actor must write it out.
        return_log_prob=True,
    )
    value_network = nn.Linear(2, 1)
    value_operator = ValueOperator(
        module=value_network,
        in_keys=["states"],
        out_keys=["values"],
    )
    gae = GAE(gamma=0.98, lmbda=0.95, value_network=value_operator)
    # Must match the out_keys of `value_operator`.
    gae.set_keys(value="values")

    loss_module = ClipPPOLoss(actor_network=actor, critic_network=value_operator)
    # Must match the out_keys of `value_operator`.
    loss_module.set_keys(value="values")

    # Fully populate the "next" step before nesting it, so the root TensorDict
    # does not depend on later in-place writes propagating through the nesting.
    next_tensor_dict = TensorDict(
        {
            "states": torch.tensor([[4, 5], [6, 7]], dtype=torch.float32),
            "reward": torch.tensor([[1], [-1]], dtype=torch.float32),
            "done": torch.tensor([[True], [True]], dtype=torch.bool),
            "terminated": torch.tensor([[True], [True]], dtype=torch.bool),
        },
        batch_size=2,
    )
    current_tensor_dict = TensorDict(
        {"states": torch.tensor([[0, 1], [2, 3]], dtype=torch.float32)},
        batch_size=2,
    )

    # Simulate data collection: rollout data must not carry autograd history.
    # Running the actor and GAE without gradients means "sample_log_prob" does
    # not require grad, which avoids
    # "RuntimeError: tensordict prev_log_prob requires grad." without a manual
    # `.detach()`. The loss module re-runs the networks with gradients enabled.
    with torch.no_grad():
        # Both modules write their outputs into the TensorDict they receive.
        actor(current_tensor_dict)
        actor(next_tensor_dict)
        current_tensor_dict["next"] = next_tensor_dict
        gae(current_tensor_dict)

    # `in_keys` lists every entry ClipPPOLoss reads from its input TensorDict.
    print(f"ClipPPOLoss in_keys: {loss_module.in_keys}")

    loss_tensor_dict = loss_module(current_tensor_dict)
    print(f"loss_tensor_dict: {loss_tensor_dict}")

    # Each "loss_*" entry is a 0-dim tensor; summing them gives the scalar loss.
    loss = (
        loss_tensor_dict["loss_critic"]
        + loss_tensor_dict["loss_entropy"]
        + loss_tensor_dict["loss_objective"]
    )
    print(f"loss: {loss.item()}")


if __name__ == "__main__":
    main()

Here is the console output:

loss_tensor_dict: TensorDict(
    fields={
        ESS: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False),
        clip_fraction: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False),
        entropy: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False),
        kl_approx: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False),
        loss_critic: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False),
        loss_entropy: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False),
        loss_objective: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False)},
    batch_size=torch.Size([]),
    device=None,
    is_shared=False)
loss: -0.431304931640625

All losses (and TensorDictModules) have an `in_keys` attribute that tells you which entries they expect.
In the case of PPO, GAE is computed internally if the "advantage" key is missing from the input.