In this PPO tutorial, the split_trajs argument of the SyncDataCollector is False. However, I want to split the collected data into trajectories and learn from them. If I set this argument to True, the collector returns the data split per trajectory, but each trajectory is zero-padded to a common length. I want to remove this zero padding from the training data.
collector = SyncDataCollector(
    env,
    policy_module,
    frames_per_batch=frames_per_batch,
    total_frames=total_frames,
    split_trajs=True,
    device=device,
)
# ...
for i, tensordict_data in enumerate(collector):
    for td_id, td_trajectory in enumerate(tensordict_data):
        mask = td_trajectory["collector", "mask"]
        # Now I want to remove the zero padding that each tensor in the
        # trajectory tensordict has, based on the mask (each tensor has a
        # different size and number of dimensions)
    for _ in range(num_epochs):
        # We'll need an "advantage" signal to make PPO work.
        # We re-compute it at each epoch as its value depends on the value
        # network which is updated in the inner loop.
        advantage_module(tensordict_data)
There are masks in tensordict["collector", "mask"], but I don't know how to apply them to the entire tensordict and remove the zero padding in one go. Each tensor in the tensordict has a different shape and number of dimensions, so naively applying torch.masked_select raises an error (or flattens the result). A hand-written per-key implementation feels very cumbersome. Any ideas would be appreciated.
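For reference, the per-key operation I'd like to avoid writing by hand for every key is plain boolean indexing along the first (batch) dimension, which already copes with different trailing shapes. This is just an illustrative sketch: the shapes mirror my tensordict, but the mask keeping 10 valid steps is made up.

```python
import torch

# Illustrative stand-ins for two keys of the trajectory tensordict
# (batch size 14, with a made-up mask keeping 10 valid steps).
mask = torch.tensor([True] * 10 + [False] * 4)
action = torch.zeros(14, 1)
pixels = torch.zeros(14, 3, 28, 28)

# Boolean indexing along the first (batch) dimension drops the padded
# steps regardless of the trailing feature dimensions -- unlike
# torch.masked_select, which flattens everything to 1D.
print(action[mask].shape)  # torch.Size([10, 1])
print(pixels[mask].shape)  # torch.Size([10, 3, 28, 28])
```

What I'm after is the equivalent of this applied to the whole tensordict at once, without looping over keys myself.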
Here is an example of the TensorDict structure in my environment:
TensorDict(
    fields={
        action: Tensor(shape=torch.Size([14, 1]), device=cuda:0, dtype=torch.float32, is_shared=True),
        collector: TensorDict(
            fields={
                mask: Tensor(shape=torch.Size([14]), device=cuda:0, dtype=torch.bool, is_shared=True),
                traj_ids: Tensor(shape=torch.Size([14]), device=cuda:0, dtype=torch.int64, is_shared=True)},
            batch_size=torch.Size([14]),
            device=cuda:0,
            is_shared=True),
        done: Tensor(shape=torch.Size([14, 1]), device=cuda:0, dtype=torch.bool, is_shared=True),
        loc: Tensor(shape=torch.Size([14, 1]), device=cuda:0, dtype=torch.float32, is_shared=True),
        next: TensorDict(
            fields={
                done: Tensor(shape=torch.Size([14, 1]), device=cuda:0, dtype=torch.bool, is_shared=True),
                observation: Tensor(shape=torch.Size([14, 11]), device=cuda:0, dtype=torch.float32, is_shared=True),
                pixels: Tensor(shape=torch.Size([14, 3, 28, 28]), device=cuda:0, dtype=torch.float32, is_shared=True),
                reward: Tensor(shape=torch.Size([14, 1]), device=cuda:0, dtype=torch.float32, is_shared=True),
                step_count: Tensor(shape=torch.Size([14, 1]), device=cuda:0, dtype=torch.int64, is_shared=True),
                terminated: Tensor(shape=torch.Size([14, 1]), device=cuda:0, dtype=torch.bool, is_shared=True),
                truncated: Tensor(shape=torch.Size([14, 1]), device=cuda:0, dtype=torch.bool, is_shared=True)},
            batch_size=torch.Size([14]),
            device=cuda:0,
            is_shared=True),
        observation: Tensor(shape=torch.Size([14, 11]), device=cuda:0, dtype=torch.float32, is_shared=True),
        pixels: Tensor(shape=torch.Size([14, 3, 28, 28]), device=cuda:0, dtype=torch.float32, is_shared=True),
        sample_log_prob: Tensor(shape=torch.Size([14]), device=cuda:0, dtype=torch.float32, is_shared=True),
        scale: Tensor(shape=torch.Size([14, 1]), device=cuda:0, dtype=torch.float32, is_shared=True),
        step_count: Tensor(shape=torch.Size([14, 1]), device=cuda:0, dtype=torch.int64, is_shared=True),
        terminated: Tensor(shape=torch.Size([14, 1]), device=cuda:0, dtype=torch.bool, is_shared=True),
        truncated: Tensor(shape=torch.Size([14, 1]), device=cuda:0, dtype=torch.bool, is_shared=True)},
    batch_size=torch.Size([14]),
    device=cuda:0,
    is_shared=True)
Thank you