'model.eval()' vs 'with torch.no_grad()'

Hi, are you sure bn and dropout work in eval mode? I think bn and dropout work in training mode, not in validation and test mode.

Hi,

There is no such thing as “test mode”.
Only train() and eval().
Both bn and dropout will work in both cases, but with different behaviour, as you expect them to behave differently during training and evaluation. For example, during evaluation dropout should be disabled, so it is replaced with a no-op. Similarly, bn should use its saved running statistics instead of the batch statistics, and that’s what it does in eval mode.
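A minimal sketch of that difference, using standalone layers just for illustration:

import torch
import torch.nn as nn

drop = nn.Dropout(p=0.5)
bn = nn.BatchNorm1d(4)
x = torch.randn(8, 4)

# Training mode: dropout zeroes roughly half the activations, bn uses batch statistics.
drop.train(); bn.train()
print(drop(x))          # some entries are zeroed, the rest are scaled by 1/(1-p)
print(bn(x).mean(0))    # roughly zero mean per feature (normalized with batch stats)

# Eval mode: dropout is a no-op, bn normalizes with its running estimates.
drop.eval(); bn.eval()
print(torch.equal(drop(x), x))  # True: dropout passes the input through unchanged
print(bn(x).mean(0))            # uses running_mean / running_var collected during training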


You might want to modify your response as it can easily confuse readers. Your comment says “batchnorm or dropout layers will work in eval model instead of training mode.” I think you wanted to write eval mode, not eval model.

Thanks, I edited the answer above.

I understood that:

  • eval() changes the bn and dropout layer’s behaviour

  • torch.no_grad() deals with the autograd engine and stops it from calculating the gradients, which is the recommended way of doing validation

BUT, I didn’t understand the use of with torch.set_grad_enabled().

Can you please explain what its use is and where exactly it can be used?
Thanks! :slight_smile:

torch.set_grad_enabled lets you enable or disable the gradient calculations using a bool argument.
Have a look at the docs for example usage.
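For example, something along these lines (a small sketch, where is_train is just an illustrative flag):

import torch

x = torch.randn(3, requires_grad=True)

is_train = False  # e.g. flip this depending on whether you are training or validating
with torch.set_grad_enabled(is_train):
    y = (x * 2).sum()

print(y.requires_grad)  # False here; it would be True if is_train were True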

But torch.no_grad() does the same thing. Is there any difference between these two?

torch.no_grad just disables the gradient calculation, while torch.set_grad_enabled sets gradient calculation to on or off based on the passed argument.

Are you saying that torch.no_grad and torch.set_grad_enabled(False) are the same?


Yes, if you are using it as a context manager. torch.set_grad_enabled can “globally” enable/disable the gradient computation, if you call it as a function.
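Roughly, a sketch of the two usages:

import torch

x = torch.randn(2, requires_grad=True)

# As a context manager: only affects the enclosed block.
with torch.set_grad_enabled(False):
    print((x * 2).requires_grad)  # False
print((x * 2).requires_grad)      # True again outside the block

# As a plain function call: changes the global default until you change it back.
torch.set_grad_enabled(False)
print((x * 2).requires_grad)      # False
torch.set_grad_enabled(True)      # restore the default
print((x * 2).requires_grad)      # True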


The method is called “inverted dropout”, and its purpose is to ensure that the expectation of the dropout layer’s output remains unchanged.

By the way, if inverted dropout is not applied (i.e. you don’t scale by 1/(1-p)), the dropout layer’s output changes significantly from pass to pass (the mask follows a Bernoulli distribution, so you never know how many nodes are dropped each time). As a result, the output of the whole network cannot stay stable, which disturbs backpropagation.
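A quick numerical sketch of that 1/(1-p) scaling (exact values vary from run to run):

import torch
import torch.nn as nn

p = 0.5
drop = nn.Dropout(p=p)
drop.train()

x = torch.ones(1_000_000)
out = drop(x)

# Surviving elements are scaled by 1/(1-p), so the mean stays close to the input's mean.
print(out[out != 0].unique())  # tensor([2.]) == 1 / (1 - p)
print(out.mean())              # roughly 1.0, matching x.mean()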

Another perspective is What is inverted dropout?


Thanks for the awesome explanation, but I feel I’m missing one piece for the distinction. Why is it necessary to be able to backprop when doing model.eval()?

Hi,

It’s not “necessary” to be able to backprop when doing .eval(). It’s just that .eval() has nothing to do with the autograd engine and the backprop capabilities.
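For instance, a minimal sketch with a toy linear model:

import torch
import torch.nn as nn

model = nn.Linear(4, 2)
model.eval()  # only switches layer behavior (dropout/batchnorm); autograd is untouched

out = model(torch.randn(8, 4)).sum()
out.backward()                        # backprop still works in eval mode
print(model.weight.grad is not None)  # True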


Why is the model forward pass slow while using torch.no_grad()?

Hi,

I don’t see any mention of speed in this blog post.
Can you detail your question a bit more please?

Hi @ptrblck, after a torch.no_grad block, when switching back from model.eval() to model.train(), is it required to re-enable gradients with torch.set_grad_enabled(True), or will gradients be enabled automatically with model.train()? I just want to confirm; I think they should be enabled automatically.

model.train() and model.eval() do not change any behavior of the gradient calculations, but are used to switch specific layers such as dropout and batchnorm between training and evaluation behavior (in eval mode dropout won’t drop activations and batchnorm will use its running estimates instead of the batch statistics).

After the with torch.no_grad() block has been executed, your gradient behavior will be the same as it was before entering the block.
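For example (a small sketch with a toy nn.Linear model):

import torch
import torch.nn as nn

model = nn.Linear(4, 2)

model.eval()
with torch.no_grad():
    val_out = model(torch.randn(8, 4))
print(val_out.requires_grad)    # False: no graph was built inside the block

model.train()                   # only switches dropout/batchnorm behavior
train_out = model(torch.randn(8, 4))
print(train_out.requires_grad)  # True: gradients are enabled again automatically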


Thanks for your explanation.
I am actually more interested in the usage of model.eval() and torch.no_grad()…

So that means during evaluation, it’s enough to use:

model.eval()
for batch in val_loader:
    #some code

or I need to use them as:

model.eval()
with torch.no_grad():
    for batch in val_loader:
        #some code

Thanks

The first approach is enough to get valid results.
The second approach will additionally save some memory.

Thanks! That helps a lot. :+1::+1: