Should we split the trajectories prior to calculating the loss for a DQN?

Consider this minimal example where data is sampled from a collector according to an MLP policy:

from tensordict.nn import TensorDictModule as Mod, TensorDictSequential as Seq
from torchrl.collectors import SyncDataCollector
from torchrl.collectors.utils import split_trajectories
from torchrl.modules import MLP, QValueModule
from torchrl.objectives import DQNLoss

# env, device and frames_per_batch are assumed to be defined elsewhere
mlp = Mod(
    MLP(
        out_features=env.action_spec.shape[-1],
        num_cells=[10],
        device=device,
    ),
    in_keys=["observation"],
    out_keys=["action_value"],
).to(device)
qval = QValueModule(spec=env.action_spec)
test_policy = Seq(mlp, qval)
test_policy(env.reset(seed=0))

collector = SyncDataCollector(
    env,
    test_policy,
    frames_per_batch=frames_per_batch,
    total_frames=-1
)
loss = DQNLoss(test_policy, action_space=env.action_spec, loss_function='smooth_l1')

data = next(iter(collector))
print("Without split", loss(data.to(device))["loss"])
print("With split", loss(split_trajectories(data).to(device))["loss"])

The outputs are:

Without split tensor(91.7286, device='cuda:0', grad_fn=<MeanBackward0>)
With split tensor(33.9880, device='cuda:0', grad_fn=<MeanBackward0>)

Which of these two cases should I use for training a model?

I don’t have access to the env and the other parameters in this script, but my take is that in your second example (with split) the input data is padded, and the loss is computed over all elements, padded and non-padded alike (I don’t think DQNLoss checks for a mask key when it is executed). If that’s a desired feature, feel free to open an issue!
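To see why zero-padding dilutes the mean, here is a minimal pure-Python sketch. The per-step loss values are made up for illustration; the point is only the arithmetic: averaging over a flat concatenation of trajectories versus averaging over trajectories padded to a common length, as splitting would produce.

```python
# Hypothetical per-step losses for two trajectories of unequal length.
traj_a = [2.0, 4.0, 6.0]
traj_b = [8.0]

# Mean over the flat concatenation (the unsplit case).
concat = traj_a + traj_b
loss_concat = sum(concat) / len(concat)  # 20.0 / 4 = 5.0

# Zero-pad the shorter trajectory to the longest length (the split case).
max_len = max(len(traj_a), len(traj_b))
padded = traj_a + traj_b + [0.0] * (max_len - len(traj_b))
loss_padded = sum(padded) / len(padded)  # 20.0 / 6 ≈ 3.33, diluted by the padding

# A mask over the valid steps recovers the correct mean.
mask = [True] * len(traj_a) + [True] * len(traj_b) + [False] * (max_len - len(traj_b))
valid = [x for x, m in zip(padded, mask) if m]
loss_masked = sum(valid) / len(valid)  # back to 5.0

print(loss_concat, loss_padded, loss_masked)
```

The smaller "with split" number in your output is consistent with this dilution effect: padded zeros pull the mean loss down without the model being any better.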

But TL;DR: the first (non-split) one is the right one. There is no need to split the data; everything is designed to work with concatenations of trajectories.
