What is the exact format of the input TensorDict for ClipPPOLoss's forward method?

I checked the official documentation for `ClipPPOLoss`.

For `forward(tensordict: TensorDictBase) → TensorDictBase`, the documentation describes the parameter as follows:

tensordict – an input tensordict with the values required to compute the loss.

This is too vague: I had a hard time figuring out what the input tensordict is supposed to look like (which keys it must contain and what their values should be).
It seems `ClipPPOLoss` uses GAE internally, so presumably its input tensordict should look like the input tensordict for GAE.
Does anyone know the exact format of the input tensordict for `ClipPPOLoss`?

FYI, I made the following investigation notes:

import torch
from tensordict import TensorDict
from torch import nn
from tensordict.nn import TensorDictModule, InteractionType
from torch.distributions import Categorical
from torchrl.modules import ProbabilisticActor, ValueOperator
from torchrl.objectives import ClipPPOLoss
from torchrl.objectives.value import GAE


def main():
    """Build a toy discrete-action PPO setup and compute ``ClipPPOLoss`` once.

    A linear categorical policy and a linear critic both read the ``"states"``
    key. A two-sample batch holding one fake transition is run through the
    actor and GAE, and the resulting tensordict is fed to ``ClipPPOLoss``.
    The loss tensordict and the summed loss are printed.
    """
    policy_network = nn.Linear(2, 4)
    policy_module = TensorDictModule(
        module=policy_network,
        in_keys=["states"],
        out_keys=["logits"],
    )
    actor = ProbabilisticActor(
        module=policy_module,
        in_keys=["logits"],
        # There must be an "action" key for using `ClipPPOLoss`.
        out_keys=["action"],
        distribution_class=Categorical,
        # NOTE(review): MODE picks the most likely action deterministically;
        # real PPO rollouts usually sample from the policy instead
        # (e.g. InteractionType.RANDOM). Kept as-is for this demo.
        default_interaction_type=InteractionType.MODE,
        # Must be True for `ClipPPOLoss`: the actor then also writes
        # "sample_log_prob", the log-prob of the taken action.
        return_log_prob=True,
    )
    value_network = nn.Linear(2, 1)
    value_operator = ValueOperator(
        module=value_network,
        in_keys=["states"],
        out_keys=["values"],
    )
    gae = GAE(
        gamma=0.98,
        lmbda=0.95,
        value_network=value_operator,
    )
    # Must match the out_keys of `value_operator`.
    gae.set_keys(value="values")

    # PPO
    loss_module = ClipPPOLoss(
        actor_network=actor,
        critic_network=value_operator,
    )
    # Must match the out_keys of `value_operator`.
    loss_module.set_keys(value="values")

    # Current observations for a batch of 2.
    current_tensor_dict = TensorDict(
        {"states": torch.tensor([[0.0, 1.0], [2.0, 3.0]])},
        batch_size=[2],
    )

    # Act without autograd. ClipPPOLoss raises
    # "RuntimeError: tensordict prev_log_prob requires grad." when
    # "sample_log_prob" carries a graph, because the old policy's log-prob must
    # be a constant. Running under no_grad is equivalent to detaching it later.
    with torch.no_grad():
        # The actor writes "logits", "action" and "sample_log_prob" into the
        # tensordict in place, so the return value can be ignored.
        actor(current_tensor_dict)

    # The next step only needs observations plus reward/end-of-episode flags
    # for GAE; the actor does not have to run on it. All keys are filled in
    # before linking, so nothing relies on "next" aliasing this object.
    next_tensor_dict = TensorDict(
        {
            "states": torch.tensor([[4.0, 5.0], [6.0, 7.0]]),
            "reward": torch.tensor([[1.0], [-1.0]]),
            "done": torch.tensor([[True], [True]]),
            "terminated": torch.tensor([[True], [True]]),
        },
        batch_size=[2],
    )
    current_tensor_dict["next"] = next_tensor_dict

    # GAE adds its outputs (advantage, value target, value estimates) to the
    # tensordict in place.
    gae(current_tensor_dict)

    # The loss re-runs the actor and critic on "states" and reads the keys
    # produced above.
    loss_tensor_dict = loss_module(current_tensor_dict)
    print(f"loss_tensor_dict: {loss_tensor_dict}")

    # Each "loss_*" entry is a scalar tensor (shape torch.Size([])).
    loss_critic = loss_tensor_dict["loss_critic"]
    loss_entropy = loss_tensor_dict["loss_entropy"]
    loss_objective = loss_tensor_dict["loss_objective"]
    loss = loss_critic + loss_entropy + loss_objective
    print(f"loss: {loss}")

# Run the demo only when executed as a script, not on import.
if __name__ == "__main__":
    main()

Here is the console output:

loss_tensor_dict: TensorDict(
    fields={
        ESS: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False),
        clip_fraction: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False),
        entropy: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False),
        kl_approx: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False),
        loss_critic: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False),
        loss_entropy: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False),
        loss_objective: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False)},
    batch_size=torch.Size([]),
    device=None,
    is_shared=False)
loss: -0.431304931640625

All losses (and `TensorDictModule`s) have an `in_keys` attribute that tells you which keys they expect.
In the case of PPO, the advantage is computed internally with GAE if the `"advantage"` key is missing.

I got a better understanding after spending more time on the investigation.

Here is the updated note explaining the input tensordict for `ClipPPOLoss`:

import torch
import torch.nn as nn
from tensordict import TensorDict
from tensordict.nn import TensorDictModule, InteractionType
from torch.distributions import Categorical
from torchrl.modules import ProbabilisticActor, ValueOperator
from torchrl.objectives import ClipPPOLoss
from torchrl.objectives.value import GAE

seed = 3
# `torch.manual_seed` seeds the random number generators of all devices,
# CUDA included, so a separate `torch.cuda.manual_seed_all` call is redundant.
torch.manual_seed(seed)

def main():
    """Show the tensordict layout that ``ClipPPOLoss`` consumes.

    A linear categorical policy and a linear critic both read
    ``"observation"``. A two-sample batch holding one fake transition is run
    through the actor and GAE. The populated tensordict is printed, then the
    PPO loss terms and their sum are computed and printed.
    """
    policy_network = nn.Linear(2, 4)
    policy_module = TensorDictModule(
        module=policy_network,
        in_keys=["observation"],
        out_keys=["logits"],
    )
    actor = ProbabilisticActor(
        module=policy_module,
        in_keys=["logits"],
        out_keys=["action"],
        distribution_class=Categorical,
        # NOTE(review): MODE picks the most likely action deterministically;
        # real PPO rollouts usually sample from the policy instead
        # (e.g. InteractionType.RANDOM). Kept as-is for this demo.
        default_interaction_type=InteractionType.MODE,
        # Must be True for `ClipPPOLoss`: the actor then also writes
        # "sample_log_prob", the log-prob of the taken action.
        return_log_prob=True,
    )
    value_network = nn.Linear(2, 1)
    value_operator = ValueOperator(
        module=value_network,
        in_keys=["observation"],
        out_keys=["value"],
    )
    gae = GAE(
        gamma=0.98,
        lmbda=0.95,
        value_network=value_operator,
    )
    gae.set_keys(
        advantage="advantage",
        value_target="value_target",
        value="value",
    )

    # PPO
    loss_module = ClipPPOLoss(
        actor_network=actor,
        critic_network=value_operator,
    )
    # Keys ClipPPOLoss reads from its input tensordict.
    # The loss also re-runs the policy and value networks, so their in_keys
    # ("observation" here) must be present as well. Those are not part of the
    # loss's accepted keys, so they are not set here.
    # Re-running the policy is what PPO's clipped surrogate objective needs: it
    # compares the current policy's log-prob of the stored action with the
    # stored "sample_log_prob" of the old policy. Clipping that ratio avoids
    # drastic policy changes and keeps training stable.
    loss_module.set_keys(
        advantage="advantage",  # output of `GAE`
        value_target="value_target",  # output of `GAE`
        value="value",  # output of the value network
        action="action",  # output of `ProbabilisticActor`
        sample_log_prob="sample_log_prob",  # output of `ProbabilisticActor`
    )

    # Initial input tensordict; the actor and GAE add keys to it in place.
    current_tensor_dict = TensorDict(
        {"observation": torch.tensor([[0.0, 1.0], [2.0, 3.0]])},
        batch_size=[2],
    )
    # Act without autograd. ClipPPOLoss raises
    # "RuntimeError: tensordict prev_log_prob requires grad." when
    # "sample_log_prob" carries a graph, because the old policy's log-prob must
    # be a constant. Running under no_grad is equivalent to detaching it later.
    with torch.no_grad():
        actor(current_tensor_dict)

    # The next step of the transition, needed by `GAE`: the new observation
    # after taking the action, plus the reward and end-of-episode flags.
    # All keys are filled in before linking, so nothing relies on "next"
    # aliasing this object.
    next_tensor_dict = TensorDict(
        {
            "observation": torch.tensor([[4.0, 5.0], [6.0, 7.0]]),
            "reward": torch.tensor([[1.0], [-1.0]]),
            "done": torch.tensor([[True], [True]]),
            "terminated": torch.tensor([[True], [True]]),
        },
        batch_size=[2],
    )
    current_tensor_dict["next"] = next_tensor_dict

    # Adds "advantage" and "value_target", plus the "value" estimates for the
    # current and next observations, to the tensordict in place.
    # Calling it explicitly is optional: when "advantage" is missing,
    # ClipPPOLoss computes it with its own value estimator (configurable via
    # `loss_module.make_value_estimator(...)`). That estimator does not use the
    # gamma/lmbda configured on `gae` above.
    # "advantage" and "value_target" may also be computed by any custom logic
    # and written directly into the tensordict, e.g. when GAE is a poor fit
    # for the value network (such as a recurrent one).
    gae(current_tensor_dict)

    # This shows the mutated result.
    print(f"current_tensor_dict: {current_tensor_dict}")

    loss_tensor_dict = loss_module(current_tensor_dict)
    print(f"loss_tensor_dict: {loss_tensor_dict}")
    # Each "loss_*" entry is a scalar tensor (shape torch.Size([])).
    loss_critic = loss_tensor_dict["loss_critic"]
    loss_entropy = loss_tensor_dict["loss_entropy"]
    loss_objective = loss_tensor_dict["loss_objective"]
    loss = loss_critic + loss_entropy + loss_objective
    print(f"loss: {loss}")

# Run the demo only when executed as a script, not on import.
if __name__ == "__main__":
    main()

And here is the console output:

current_tensor_dict: TensorDict(
    fields={
        action: Tensor(shape=torch.Size([2]), device=cpu, dtype=torch.int64, is_shared=False),
        advantage: Tensor(shape=torch.Size([2, 1]), device=cpu, dtype=torch.float32, is_shared=False),
        logits: Tensor(shape=torch.Size([2, 4]), device=cpu, dtype=torch.float32, is_shared=False),
        next: TensorDict(
            fields={
                done: Tensor(shape=torch.Size([2, 1]), device=cpu, dtype=torch.bool, is_shared=False),
                observation: Tensor(shape=torch.Size([2, 2]), device=cpu, dtype=torch.float32, is_shared=False),
                reward: Tensor(shape=torch.Size([2, 1]), device=cpu, dtype=torch.float32, is_shared=False),
                terminated: Tensor(shape=torch.Size([2, 1]), device=cpu, dtype=torch.bool, is_shared=False),
                value: Tensor(shape=torch.Size([2, 1]), device=cpu, dtype=torch.float32, is_shared=False)},
            batch_size=torch.Size([2]),
            device=None,
            is_shared=False),
        observation: Tensor(shape=torch.Size([2, 2]), device=cpu, dtype=torch.float32, is_shared=False),
        sample_log_prob: Tensor(shape=torch.Size([2]), device=cpu, dtype=torch.float32, is_shared=False),
        value: Tensor(shape=torch.Size([2, 1]), device=cpu, dtype=torch.float32, is_shared=False),
        value_target: Tensor(shape=torch.Size([2, 1]), device=cpu, dtype=torch.float32, is_shared=False)},
    batch_size=torch.Size([2]),
    device=None,
    is_shared=False)
loss_tensor_dict: TensorDict(
    fields={
        ESS: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False),
        clip_fraction: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False),
        entropy: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False),
        kl_approx: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False),
        loss_critic: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False),
        loss_entropy: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False),
        loss_objective: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False)},
    batch_size=torch.Size([]),
    device=None,
    is_shared=False)
loss: -0.5069577097892761