Opinion: .eval() should be a context manager

 with mymodel.eval():
     eval_res = mymodel(test_data)

What is your opinion on this?
Do you run into errors or annoying situations without the context manager?

In my opinion, the fewer context managers the better. Especially if you are using PEP8! :wink:

While the torch.no_grad() context manager makes total sense, since gradient calculation was previously a property of the data (Variable(..., volatile=False)) and is now a "property" of the workflow, the eval property belongs to the model in my opinion, since it changes the behavior of some internal Modules. With such a context manager it might also be more complicated to set some layers to eval while others stay in train.
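For illustration, a minimal sketch of the per-layer case using the current method-based API (the model structure here is hypothetical, and we never run a forward pass, so the shapes are arbitrary):

import torch.nn as nn

# Hypothetical model: a backbone containing BatchNorm plus a small head
model = nn.Sequential(
    nn.Sequential(nn.Conv2d(3, 8, 3), nn.BatchNorm2d(8)),  # "backbone"
    nn.Linear(8, 10),                                       # "head"
)

model.train()    # put the whole model in train mode
model[0].eval()  # backbone only: BatchNorm now uses its running statistics

print(model[0][1].training)  # False: BatchNorm is in eval mode
print(model[1].training)     # True: the head is still in train mode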

Would love to hear other opinions on this!

Correctly switching train()/eval() during training/validation is tricky in the presence of exceptions or other early exits. Once in the wrong state, your training might be totally off. Context managers help in such situations. I'm using the following context manager:

# Following snippet is licensed under MIT license

from contextlib import contextmanager

@contextmanager
def evaluating(net):
    '''Temporarily switch to evaluation mode.'''
    istrain = net.training  # remember the original mode
    try:
        net.eval()
        yield net
    finally:
        # restore train mode even if the block exits via an exception
        if istrain:
            net.train()

Intended usage: with evaluating(net): .... In case you are concerned about having too many nested blocks, you can collapse multiple context managers into a single line, with evaluating(net), torch.no_grad():, in Python 3.x and later. In 2.7 you need to use nested from contextlib.
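For completeness, a short usage sketch combining both context managers (net and test_data are placeholder names from the posts above):

import torch

# Python 3: both context managers on one line
with evaluating(net), torch.no_grad():
    eval_res = net(test_data)
# after the block, net.training is restored and gradient tracking is active again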

Clever, thank you for sharing!

Thank you @Christoph_Heindl. Don't suppose… could you explicitly grant a license to use the code, e.g. MIT (or something else that is relatively free, i.e. not GPLv3 or similar)?

@hughperkins modified above snippet to contain license info.

One thing that should be clarified by someone with internal insight is whether net.training is always a primitive boolean and not something more complex. Otherwise, remembering the training state before switching to eval would require an explicit copy.
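A quick runtime check (a sketch, not an authoritative answer) at least shows what the attribute looks like from the outside:

import torch.nn as nn

net = nn.Linear(4, 2)
print(type(net.training))  # <class 'bool'>
net.eval()
print(net.training)        # False: the attribute is rebound, not mutated in place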

Great! Thanks! :slight_smile:

With the current system, it's also easy to make the mistake of calling model.eval() at the start of the validation loop and model.train() at the end of it, which can change model.training unexpectedly if the model was already in eval mode when the validation loop started. The correct way is to record the state of the model before entering eval mode and restore it afterwards, and this context manager can help avoid this error.
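To make the failure mode concrete, a sketch of the problematic pattern versus the state-restoring one (model and loader are placeholder names):

# Problematic: unconditionally re-enables train mode
def validate_naive(model, loader):
    model.eval()
    for batch in loader:
        model(batch)
    model.train()  # wrong if the model was already in eval mode before the call

# Safer: record and restore the previous state
def validate(model, loader):
    was_training = model.training
    model.eval()
    try:
        for batch in loader:
            model(batch)
    finally:
        model.train(was_training)  # model.train(False) is equivalent to model.eval()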

Another thing: if eval() were a context manager, there should be a train() context manager as well. The use case would be ensuring a certain model state for a block of code and restoring the model to its original state after finishing that block.
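Such a counterpart could be written symmetrically to the evaluating helper above (a sketch, not an official API):

from contextlib import contextmanager

@contextmanager
def training(net):
    '''Temporarily switch to training mode.'''
    was_training = net.training
    try:
        net.train()
        yield net
    finally:
        net.train(was_training)  # restore whatever mode the model was in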

Something else to consider: the name of this context manager (and of other, related context managers) should be chosen carefully. For example, it would be easy to think that with torch.inference_mode(): puts the model in eval() mode, when it doesn't.

I agree with @jastern33; I landed here after being confused that with torch.inference_mode(): does not set the model to eval mode.
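A small sketch of that pitfall, assuming PyTorch 1.9+ where inference_mode is available: it disables autograd tracking but leaves module modes untouched.

import torch
import torch.nn as nn

net = nn.Sequential(nn.Linear(2, 2), nn.Dropout(p=0.5))
net.train()

with torch.inference_mode():
    print(net.training)        # still True: Dropout keeps dropping activations
    out = net(torch.randn(1, 2))
    print(out.requires_grad)   # False: no autograd tracking inside the block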