Best way to temporarily disable dropout

Hello everyone,

I want to set dropout to zero during one forward pass and re-enable it for the next forward pass. For both passes, I still want components like batch norms to stay in training mode (i.e. NOT use running statistics such as the exponential moving average).

My current approach is to use model.eval(). The problem is that this also switches other components like batch norms (see list here) to running statistics instead of the current batch's statistics.
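To make the problem concrete, here is a minimal sketch (the toy `model` is a hypothetical stand-in, not from the original post) showing that `model.eval()` recursively sets `.training = False` on every submodule, BatchNorm included:

```python
import torch.nn as nn

# Hypothetical toy model containing both BatchNorm and Dropout.
model = nn.Sequential(nn.Linear(4, 8), nn.BatchNorm1d(8), nn.Dropout(p=0.5))

model.eval()  # recursively sets .training = False on every submodule
flags = {type(m).__name__: m.training for m in model.modules()}
print(flags)
# {'Sequential': False, 'Linear': False, 'BatchNorm1d': False, 'Dropout': False}
```

In that state `nn.BatchNorm1d` normalizes with its `running_mean`/`running_var` buffers rather than the statistics of the current batch.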

Are there any convenient ways to temporarily disable dropout without affecting the other components?

As a bonus: if possible, it would be nice to still have the flag set to False when dropout is disabled, so I could use it in another if statement.


You could call .eval() on the dropout layer only or change its internal .training attribute by directly accessing it in the model.
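Building on that suggestion, a minimal sketch of a context manager that calls `.eval()` only on the `nn.Dropout` submodules and restores their previous mode afterwards. `dropout_disabled` and the toy model are hypothetical names for illustration, and the sketch assumes dropout is applied through `nn.Dropout` modules (not `nn.Dropout2d` or functional calls):

```python
import contextlib
import torch
import torch.nn as nn

@contextlib.contextmanager
def dropout_disabled(model: nn.Module):
    """Hypothetical helper: temporarily put only nn.Dropout submodules in eval mode."""
    dropouts = [m for m in model.modules() if isinstance(m, nn.Dropout)]
    previous = [m.training for m in dropouts]  # remember each layer's mode
    try:
        for m in dropouts:
            m.eval()  # sets m.training = False; BatchNorm etc. are untouched
        yield dropouts
    finally:
        for m, was_training in zip(dropouts, previous):
            m.train(was_training)  # restore the original mode

torch.manual_seed(0)
model = nn.Sequential(nn.Linear(4, 8), nn.BatchNorm1d(8), nn.Dropout(p=0.5))
model.train()
x = torch.randn(16, 4)

with dropout_disabled(model) as dropouts:
    out_a = model(x)  # no dropout, BatchNorm still uses batch statistics
    out_b = model(x)  # identical input and batch stats, so identical output
    print(dropouts[0].training, model[1].training)  # False True

print(model[2].training)  # True: dropout is re-enabled after the block
```

The try/finally restores dropout even if the forward pass raises. Inside the block the dropout layers' own `.training` attribute is False, so it can serve as the flag from the bonus question.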

Thanks, I do the following for now:

```python
import torch

dropout_modules = [m for m in model.modules() if isinstance(m, torch.nn.Dropout)]
for m in dropout_modules: m.eval()   # disable dropout (sets m.training = False)
for m in dropout_modules: m.train()  # re-enable dropout for the next forward pass
```

If I run into any further problems, I will come back to this.
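One caveat with the `isinstance(module, torch.nn.Dropout)` check: `nn.Dropout2d`, `nn.Dropout3d`, `nn.AlphaDropout` and `nn.FeatureAlphaDropout` are not subclasses of `nn.Dropout`, so they are skipped. A minimal sketch that matches an explicit tuple of the public classes (`DROPOUT_TYPES` and `set_dropout` are hypothetical names):

```python
import torch.nn as nn

# Public dropout classes; nn.Dropout2d etc. are NOT subclasses of nn.Dropout,
# so isinstance(m, nn.Dropout) alone would miss them.
DROPOUT_TYPES = (nn.Dropout, nn.Dropout2d, nn.Dropout3d,
                 nn.AlphaDropout, nn.FeatureAlphaDropout)

def set_dropout(model: nn.Module, enabled: bool) -> None:
    """Hypothetical helper: switch every dropout-type submodule on or off."""
    for m in model.modules():
        if isinstance(m, DROPOUT_TYPES):
            m.train(enabled)  # only dropout layers change mode

model = nn.Sequential(nn.Conv2d(3, 4, 3), nn.BatchNorm2d(4), nn.Dropout2d(0.2),
                      nn.Flatten(), nn.Dropout(0.5))
set_dropout(model, enabled=False)
print([m.training for m in model])  # [True, True, False, True, False]
```

Note that `set_dropout(model, enabled=True)` turns dropout on even if the rest of the model is in eval mode. Layers that apply dropout through `torch.nn.functional.dropout` instead of a dropout module (for example the attention dropout inside `nn.MultiheadAttention`) are not matched by this check at all.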