Hello everyone,
I want to set dropout to zero during one forward pass and re-enable it for the next forward pass. For both passes, I still want components like batch norms to stay in training mode (i.e. NOT use running stats such as the exponential moving average).
My current approach is to use model.eval(). The problem with this is that it also switches other components like batch norms (see list here) to use running stats instead of batch statistics.
Are there any convenient ways to temporarily disable dropout without affecting the other components?
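One possible approach, sketched under assumptions: instead of calling eval() on the whole model, call it only on the dropout submodules and restore their previous mode afterwards. This relies on the fact that Module.eval() on a leaf module such as nn.Dropout affects only that module. The helper name `dropout_disabled` is hypothetical (not a PyTorch API), and the tuple of dropout classes is an assumption you may need to extend for your model.

```python
import contextlib

import torch
import torch.nn as nn

# Assumed set of dropout layer types; extend if your model uses others.
DROPOUT_TYPES = (nn.Dropout, nn.Dropout2d, nn.Dropout3d, nn.AlphaDropout)


@contextlib.contextmanager
def dropout_disabled(model: nn.Module):
    """Hypothetical helper: put only dropout layers into eval mode, then restore them."""
    dropouts = [m for m in model.modules() if isinstance(m, DROPOUT_TYPES)]
    previous = [m.training for m in dropouts]
    try:
        for m in dropouts:
            m.eval()  # dropout becomes the identity; other layers are untouched
        yield model
    finally:
        for m, was_training in zip(dropouts, previous):
            m.train(was_training)  # restore each layer's original mode


torch.manual_seed(0)
model = nn.Sequential(nn.Linear(8, 8), nn.BatchNorm1d(8), nn.Dropout(p=0.5))
model.train()
x = torch.randn(4, 8)

with dropout_disabled(model):
    dropout_inside = model[2].training  # dropout is off here
    bn_inside = model[1].training       # batch norm still uses batch stats
    out_a = model(x)
    out_b = model(x)                    # identical: no random dropout mask
```

Inside the `with` block, BatchNorm1d keeps normalizing with batch statistics (its running stats are still updated), while dropout is a no-op. After the block, the dropout layer is back in training mode.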
As a bonus: if possible, it would be nice to still have the flag model.training set to False when dropout is disabled, as I could use it in another if statement.
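For the bonus, one option (a sketch, not an official mechanism) is to assign the top-level module's `training` attribute directly. Unlike model.train(False), plain attribute assignment does not recurse into submodules, so BatchNorm stays in training mode. The `Net` class, its training-only branch, and the `no_dropout` helper below are hypothetical illustrations; the helper is self-contained and restores all flags on exit.

```python
import contextlib

import torch
import torch.nn as nn


class Net(nn.Module):
    """Hypothetical model with a training-only branch controlled by self.training."""

    def __init__(self):
        super().__init__()
        self.bn = nn.BatchNorm1d(8)
        self.drop = nn.Dropout(p=0.5)

    def forward(self, x):
        x = self.drop(self.bn(x))
        if self.training:  # the "other if statement" from the question
            x = x + 1.0
        return x


@contextlib.contextmanager
def no_dropout(model: nn.Module):
    """Hypothetical helper: disable dropout layers and clear only the top-level flag."""
    dropouts = [m for m in model.modules() if isinstance(m, nn.Dropout)]
    previous = [m.training for m in dropouts]
    top_level = model.training
    try:
        for m in dropouts:
            m.eval()
        model.training = False  # direct assignment: does NOT recurse to children
        yield model
    finally:
        model.training = top_level
        for m, was_training in zip(dropouts, previous):
            m.train(was_training)


net = Net()
net.train()
x = torch.randn(4, 8)

with no_dropout(net):
    flags_inside = (net.training, net.bn.training, net.drop.training)
    out = net(x)

flags_after = (net.training, net.bn.training, net.drop.training)
```

Keep in mind that any code calling net.train() or net.eval() inside the block would overwrite these flags recursively.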
Cheers