Hi, in a batched env we have several trajectories running side by side. So at the end of a step, with a batch_size of 3, we could get done = False, False, True. In this case, the _reset function receives a tensordict containing only a '_reset' key, which holds a copy of the done tensor. How do I apply the reset only to the third element (the True one)? And how do I keep track of the other elements (the False ones)?

In the pendulum tutorial, _step has:

```python
done = torch.zeros_like(reward, dtype=torch.bool)
```

because the env never terminates. But imagine we want done to be True when the angle goes outside some limits: how should this be handled in the _reset function?
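For the angle-limit case, a per-element done flag can be computed in _step roughly like this. It is a minimal sketch under assumptions: angle_out_of_bounds and max_angle are hypothetical names, the angle is assumed to be a tensor of shape [*batch] (like th in the pendulum tutorial), and the result is shaped [*batch, 1] to match the reward, as the zeros_like line does.

```python
import math
import torch

def angle_out_of_bounds(th: torch.Tensor, max_angle: float) -> torch.Tensor:
    """Per-element termination flag, shaped [*batch, 1] like the reward.

    Hypothetical helper: th is assumed to have shape [*batch].
    """
    # Wrap to [-pi, pi) so angles that differ by a full turn are treated the same.
    wrapped = ((th + math.pi) % (2 * math.pi)) - math.pi
    return (wrapped.abs() > max_angle).unsqueeze(-1)

# Batch of 3: only the third pendulum is outside the +/- 1 rad window.
th = torch.tensor([0.1, -0.2, 2.0])
done = angle_out_of_bounds(th, max_angle=1.0)
print(done.squeeze(-1).tolist())  # [False, False, True]
```

Each batch element gets its own flag, so only the elements that actually left the window are marked done.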
The solution is:
```python
import torch
from tensordict import TensorDict

def _reset(self, tensordict):
    td = self.gen_params(batch_size=self.batch_size)
    # First, define an output tensordict ('out') that initializes all keys,
    # for example:
    val1 = torch.zeros(td.shape, device=self.device)
    val2 = torch.ones(td.shape, device=self.device)
    # new_action_mask is specific to my env and is computed elsewhere (not shown)
    out = TensorDict(
        {
            "params": td["params"],
            "val1": val1,
            "val2": val2,
            "action_mask": new_action_mask,
        },
        batch_size=td.shape,
    )
    # When _reset is called after a step, the incoming tensordict only holds
    # the '_reset' key, which contains the done mask to apply to 'out'.
    if tensordict is not None and "_reset" in tensordict.keys():
        reset_mask = tensordict["_reset"].squeeze(-1)
        for key, val in out.items():
            # note: both branches read from 'out', so every element gets the fresh value
            if key == "action_mask":
                # in my case the mask must be unsqueezed to broadcast with 'action_mask'
                tensordict.set(key, torch.where(reset_mask.unsqueeze(-1), val, out.get(key)))
            else:
                tensordict.set(key, torch.where(reset_mask, val, out.get(key)))
        tensordict["_reset"] = torch.zeros_like(tensordict["_reset"])
    else:
        tensordict = out
    return tensordict
```

TorchRL then takes care of the rest.
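The per-element selection those torch.where calls aim for (fresh value where '_reset' is True, previous value elsewhere) can be sketched in plain PyTorch. masked_reset is a hypothetical helper, not a TorchRL function; it generalizes the action_mask special case by padding the mask with trailing singleton dims. Note that it needs the previous values, which, per the question, the _reset call does not receive.

```python
import torch

def masked_reset(old: torch.Tensor, fresh: torch.Tensor, reset: torch.Tensor) -> torch.Tensor:
    """Take `fresh` where `reset` is True and keep `old` everywhere else.

    reset: bool tensor of shape [*batch] (e.g. the '_reset' entry with its
    trailing singleton dim squeezed); old/fresh: shape [*batch, *feature].
    """
    # Append singleton dims so the mask broadcasts over the feature dims,
    # which is what the special-cased unsqueeze(-1) does for 'action_mask'.
    extra = old.ndim - reset.ndim
    mask = reset.reshape(*reset.shape, *([1] * extra))
    return torch.where(mask, fresh, old)

# Batch of 3 where only the third element is done.
reset = torch.tensor([[False], [False], [True]]).squeeze(-1)
old_obs = torch.tensor([0.5, -0.25, 2.0])
fresh_obs = torch.zeros(3)
old_mask = torch.tensor([[True, False], [False, True], [False, False]])
fresh_mask = torch.ones(3, 2, dtype=torch.bool)

print(masked_reset(old_obs, fresh_obs, reset).tolist())    # [0.5, -0.25, 0.0]
print(masked_reset(old_mask, fresh_mask, reset).tolist())  # [[True, False], [False, True], [True, True]]
```

The first two elements keep their values and only the third is replaced, which is the behaviour the question asks for.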
Hmm, I'm not sure that

```python
tensordict["_reset"] = torch.zeros_like(tensordict["_reset"])
```

is a good idea, because I don't fully understand its role in _reset_proc_data and _reset_check_done (EnvBase in torchrl/envs/common.py)!
And what about these two options?

```
allow_done_after_reset (bool, optional): if ``True``, an environment can
    be done after a call to :meth:`reset` is made. Defaults to ``False``.

auto_reset (bool, optional): if ``True``, the env is assumed to reset
    automatically when done. Defaults to ``False``.
```
In TorchRL's batched envs (if that is what we are talking about), such as ParallelEnv or SerialEnv, the "_reset" entry instructs both step_and_maybe_reset and reset to run only over the elements marked for reset. step_and_maybe_reset makes sure that everything is masked properly, so your data is returned to you as it should be. To convince yourself that everything works, you can append a StepCounter transform and check that the step count is reset only for the done env.
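The StepCounter check can be expressed as a small helper over the step counts and done flags logged from a rollout. This is a plain-PyTorch sketch, not TorchRL code: step_count_resets_only_on_done is a hypothetical name, and it assumes both tensors have already been extracted and arranged as [batch, time] with trailing singleton dims squeezed. It also assumes the count seen on the first step after a reset is 1; adjust that convention to match what your logs record.

```python
import torch

def step_count_resets_only_on_done(step_count: torch.Tensor, done: torch.Tensor) -> bool:
    """Check logged step counts against done flags.

    Both tensors have shape [batch, time]. Assumed convention: the count seen
    on the first step after a reset is 1, otherwise it grows by 1 per step.
    """
    prev, nxt = step_count[:, :-1], step_count[:, 1:]
    was_done = done[:, :-1]
    expected = torch.where(was_done, torch.ones_like(prev), prev + 1)
    return bool((nxt == expected).all())

done = torch.tensor([[False, False, False, False],
                     [False, False, False, False],
                     [False, True, False, False]])
good = torch.tensor([[1, 2, 3, 4],
                     [1, 2, 3, 4],
                     [1, 2, 1, 2]])  # only the done env restarts
bad = torch.tensor([[1, 2, 1, 2],
                    [1, 2, 3, 4],
                    [1, 2, 1, 2]])   # env 0 restarted without being done

print(step_count_resets_only_on_done(good, done))  # True
print(step_count_resets_only_on_done(bad, done))   # False
```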