Feature Request: Consistent Dropout Implementation

I would like to put in a feature request to implement Consistent Dropout, as described in the paper "Consistent Dropout for Policy Gradient Reinforcement Learning" (Hausknecht & Wagener, 2022).

Background

Reinforcement learning typically does not use dropout layers, as they destabilize training.

What the above paper found, though, is that fixing the dropout masks so they stay unchanged within each episode gives stable training and improves overall performance for RL networks compared to using no dropout layers.
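
For intuition, here is a minimal sketch of the core idea (the names episode_start and consistent_dropout are hypothetical, not from the paper or PyTorch): sample one dropout mask at the start of each episode and reuse it at every step until the episode ends.

    import torch

    p = 0.5          # dropout probability
    hidden_dim = 64  # width of the layer the mask applies to
    mask = None      # re-sampled once per episode

    def episode_start():
        # Sample a fresh inverted-dropout mask: kept units are scaled by 1/(1-p).
        global mask
        mask = torch.bernoulli(torch.full((hidden_dim,), 1.0 - p)) / (1.0 - p)

    def consistent_dropout(x):
        # Every step of the episode sees the same mask.
        return x * mask

    episode_start()
    step_a = consistent_dropout(torch.ones(hidden_dim))
    step_b = consistent_dropout(torch.ones(hidden_dim))
    assert torch.equal(step_a, step_b)  # identical mask within the episode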

Implementation

One approach could be to add a freeze argument to nn.Dropout at instantiation, defaulting to False.

If freeze=True, the masks on those dropout layers stay fixed until a model.dropout_reset() method is called manually.
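
As a rough illustration, here is a minimal sketch of what such a layer could look like (FrozenDropout, freeze, and dropout_reset are hypothetical names from this request, not an existing PyTorch API):

    import torch
    import torch.nn as nn

    class FrozenDropout(nn.Module):
        """Dropout variant whose mask can be frozen between manual resets."""

        def __init__(self, p: float = 0.5, freeze: bool = False):
            super().__init__()
            self.p = p
            self.freeze = freeze
            self._mask = None

        def forward(self, x):
            if not self.training or self.p == 0.0:
                return x
            if self._mask is None or not self.freeze:
                # Inverted dropout: scale kept units by 1/(1-p) so the
                # expected activation matches evaluation mode.
                keep_prob = torch.full_like(x, 1.0 - self.p)
                self._mask = torch.bernoulli(keep_prob) / (1.0 - self.p)
            return x * self._mask

        def dropout_reset(self):
            # Discard the stored mask; the next forward pass samples a new one.
            self._mask = None

Usage, resetting at each episode boundary:

    layer = FrozenDropout(p=0.5, freeze=True)
    layer.train()
    x = torch.ones(8)
    assert torch.equal(layer(x), layer(x))  # mask reused while frozen
    layer.dropout_reset()                   # new mask for the next episode

A model-level model.dropout_reset() could then just walk the module tree and call this on every such layer.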

Thank you.

Thanks for raising this.
Indeed it would be a pretty neat feature to have.
I think this would fit nicely in TorchRL (rather than torch core), as my understanding is that it is quite RL-specific.
Wdyt?

@vmoens While I think the paper demonstrates its usefulness in RL, I don't see why it wouldn't be a potential improvement in other sequence-prediction networks, such as RNNs, for the same reasons.

The reason I was suggesting TorchRL is that there we can make sure that this sort of trajectory-consistent module is easy to code, well standardized / tested / robust, and that the proper errors / warnings are raised whenever you use them.

If you think this should be a PyTorch core feature under torch.nn, feel free to open an issue on the PyTorch GitHub. Given the scope of the feature, my 2 cents is that the answer will be that this is better suited to a domain library.


We merged this in TorchRL.
You can use the layer independently or within a TorchRL script; here's an example:

    >>> import torch
    >>> from tensordict.nn import TensorDictSequential as Seq, TensorDictModule as Mod
    >>> from torchrl.envs import GymEnv, StepCounter, SerialEnv
    >>> from torchrl.modules import ConsistentDropoutModule
    >>> from torchrl.modules.utils import get_primers_from_module
    >>> m = Seq(
    ...     Mod(torch.nn.Linear(7, 4), in_keys=["observation"], out_keys=["intermediate"]),
    ...     ConsistentDropoutModule(
    ...         p=0.5,
    ...         input_shape=(2, 4),
    ...         in_keys="intermediate",
    ...     ),
    ...     Mod(torch.nn.Linear(4, 7), in_keys=["intermediate"], out_keys=["action"]),
    ... )
    >>> # The primer transform lets the envs carry the dropout masks in their tensordicts.
    >>> primer = get_primers_from_module(m)
    >>> env0 = GymEnv("Pendulum-v1").append_transform(StepCounter(5))
    >>> env1 = GymEnv("Pendulum-v1").append_transform(StepCounter(6))
    >>> env = SerialEnv(2, [lambda env=env0: env, lambda env=env1: env])
    >>> env = env.append_transform(primer)
    >>> r = env.rollout(10, m, break_when_any_done=False)
    >>> mask = [k for k in r.keys() if k.startswith("mask")][0]
    >>> # env0 resets every 5 steps, so the mask changes at the episode boundary...
    >>> assert (r[mask][0, :5] != r[mask][0, 5:6]).any()
    >>> # ...but stays fixed within an episode.
    >>> assert (r[mask][0, :4] == r[mask][0, 4:5]).all()