'model.eval()' vs 'with torch.no_grad()'

You don’t need to compute gradients when evaluating on the validation/test set. A simple
torch.no_grad() will save you a few seconds per epoch, which adds up quickly when you are tuning your model.

It’s the easiest way to optimize your workflow. Take it.
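For reference, here is a minimal sketch of a typical evaluation loop combining model.eval() with torch.no_grad(); model, val_loader, and criterion are placeholders for your own objects:

import torch

model.eval()                      # switch dropout/batchnorm layers to eval behavior
with torch.no_grad():             # skip building the autograd graph to save memory and time
    for data, target in val_loader:
        output = model(data)
        loss = criterion(output, target)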

@ptrblck Hi, I wonder: if I have a pre-trained model to be used as a supervisor (e.g. a discriminator) that will be embedded in a new model, and I want to disable dropout for this supervisor while keeping dropout for the rest of my new model, what should I do?

I know that I can manually bypass the nn.Dropout layers, but what about the internal dropout in other modules, like a multi-layer RNN or a Transformer? If I prefer to use torch.load to load the supervisor instead of its state_dict, is there a convenient way to disable dropout for this part of my model?

I have tried switching the whole model to train mode first and then setting the supervisor alone to eval mode, but it gave me

RuntimeError: cudnn RNN backward can only be called in training mode

I need the gradient :pleading_face:

This approach sounds correct.
Unfortunately, you are running into this cudnn issue, so after calling eval() on the discriminator, you would have to call train() again on the RNN.
Here is some pseudo code:

model = MyModel()
model.train()                    # not necessary, since the model is in training mode by default
model.discriminator.eval()       # disable dropout in the discriminator
model.discriminator.rnn.train()  # re-enable training mode on the RNN to work around the cudnn issue

Thank you! The code seems to work now, but I still have a question: will calling model.discriminator.rnn.train() make model.discriminator.rnn keep its dropout?

Yes, that would be the case.
You could either set the .dropout attribute to zero or disable cudnn for this layer (then you wouldn’t need to call train() on it, but might see a slowdown):

import torch
import torch.nn as nn

# option 1: keep the RNN in training mode, but zero out its dropout probability
rnn = nn.RNN(10, 20, 2, dropout=0.5).cuda()
rnn.dropout = 0.0
input = torch.randn(5, 3, 10).cuda()
h0 = torch.randn(2, 3, 20).cuda()
for _ in range(2):
    output, hn = rnn(input, h0)
    output.mean().backward()
    print(output.mean())


# option 2: keep the RNN in eval mode and disable cudnn for its forward pass
rnn = nn.RNN(10, 20, 2, dropout=0.5).cuda().eval()
input = torch.randn(5, 3, 10).cuda()
h0 = torch.randn(2, 3, 20).cuda()
for _ in range(2):
    with torch.backends.cudnn.flags(enabled=False):
        output, hn = rnn(input, h0)
    output.mean().backward()
    print(output.mean())

Oh, I see. I think that would be the best choice. Thank you so much!

@ptrblck Hi,

I have a question about a model that overfits the training set. If I use that model to predict samples from the training set, should the output perform well with both model.train() and model.eval(), or only with model.train(), where it fits the mean and variance of each batch only?

Thank you.

As so often it depends on your use case.
E.g. if your training set is shuffled in a way such that the batch statistics differ between the first iterations and the last ones, the batchnorm stats would get a bias towards the latter stats.
While this wouldn’t be visible during training, you could see an increased loss after calling model.eval() on the first part of the training dataset.
Shuffling should avoid such scenarios, but it’s also not hard to create these artificial “edge” cases.

That being said, your model should usually perform similarly during train() and eval() on the training dataset.
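If you want to check this for your own setup, a rough sketch is to compare the training-set loss in both modes; model, train_loader, and criterion are placeholders, and note that running the check in train() mode will still update the batchnorm running stats as a side effect:

import torch

for mode in ["train", "eval"]:
    if mode == "train":
        model.train()
    else:
        model.eval()
    total_loss = 0.0
    with torch.no_grad():  # gradients are not needed for this comparison
        for data, target in train_loader:
            total_loss += criterion(model(data), target).item()
    print(mode, total_loss / len(train_loader))  # the two numbers should be close for a healthy model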

@ptrblck Wonderful, I had never thought about the effect of shuffling the training data in that way.

So if we shuffle the data, we need a batch size big enough to provide good estimates of the mean and variance for the running stats, right?

If I have limited GPU memory, I can keep the performance by accumulating gradients for the weight updates and using batch norm synchronization (as in this comment). My question is: what should I do about batch norm if I only have 1 GPU?

Another option I can find is GroupNorm, but pretrained models with GroupNorm are not so common.

Using e.g. GroupNorm would be an alternative to BatchNorm layers if your batch size is small.
If that’s not possible due to pretrained models, you could also change the momentum term of the batchnorm layers to smooth the updates more.
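If you do try GroupNorm, a rough sketch of swapping out BatchNorm2d layers in an existing model could look like this (replace_bn_with_gn is just an illustrative helper, and num_groups is an arbitrary choice that has to divide the channel count):

import torch.nn as nn

def replace_bn_with_gn(module, num_groups=8):
    # recursively swap BatchNorm2d layers for GroupNorm over the same number of channels
    for name, child in module.named_children():
        if isinstance(child, nn.BatchNorm2d):
            setattr(module, name, nn.GroupNorm(num_groups, child.num_features))
        else:
            replace_bn_with_gn(child, num_groups)

Keep in mind that the pretrained batchnorm statistics and affine parameters won’t carry over to the new GroupNorm layers, so some finetuning would be needed afterwards.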

@ptrblck Thank you for your advice,

I saw in the PyTorch docs that the momentum in PyTorch is used like this:

running = (1-momentum) * running + momentum * present_variable

The default momentum is 0.1, so if I want to make the updates smoother, I should reduce it, right?

Yes, you could reduce it to lower the influence of the current batch statistic in the update.
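For example, a minimal sketch of lowering the momentum on all batchnorm layers of an existing model (0.01 is just an illustrative value, and model is a placeholder):

import torch.nn as nn

for m in model.modules():
    if isinstance(m, (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d)):
        m.momentum = 0.01  # default is 0.1; a smaller value smooths the running stats more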


Since, as albanD said earlier in this thread:

  • .eval() has nothing to do with the autograd engine and the backprop capabilities.

Why keep backprop capabilities? Why not disable backprop to make .eval() faster? I think it’s a confusing design.

The gradient calculation and backpropagation are independent from the concept of changing the behavior of layers during training and validation.
E.g. you might want to disable the gradient calculation, but still use dropout to provide noisy outputs using the same input sample in order to estimate the classification robustness etc.
Binding no_grad() to eval() would unnecessarily limit these use cases.
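As a rough sketch of such a use case, you could keep dropout active while disabling gradients to estimate the variability of the predictions for a single input; model and sample are placeholders, and note that model.train() also switches batchnorm to batch statistics, so you may want to call .eval() on those layers separately:

import torch

model.train()          # keep dropout active
with torch.no_grad():  # no gradients are needed for this estimate
    preds = torch.stack([model(sample) for _ in range(20)])
print(preds.mean(0), preds.std(0))  # mean prediction and its variability under dropout noise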

  • model.eval() will notify all your layers that you are in eval mode, that way, batchnorm or dropout layers will work in eval mode instead of training mode.

We use eval() in testing mode. So why does the statement above say that batchnorm or dropout layers will work in eval mode? They should not work in eval mode; they should work in training mode.

Could you elaborate on why someone would need backprop in an evaluation script?

but you won’t be able to backprop (which you don’t want in an eval script).

@michaelklachko @itsnamgyu @MadeUpMasters

There are a lot of methods in deep learning that require the computation of gradients at inference time, such as adversarial attacks, searching the latent code of GANs, BADGE active learning, class activation maps, etc. However, we don’t want to modify the internal running statistics of batchnorm whenever we do these types of inference.
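As a concrete sketch, an FGSM-style adversarial attack roughly looks like this: the model stays in eval() mode, yet gradients (with respect to the input) are still required. model, data, target, and criterion are placeholders, and the epsilon of 0.01 is arbitrary:

model.eval()                          # keep batchnorm/dropout in their inference behavior
data = data.clone().requires_grad_()  # we need gradients w.r.t. the input, not the parameters
loss = criterion(model(data), target)
loss.backward()                       # works fine in eval mode, since eval() does not touch autograd
adv_data = data + 0.01 * data.grad.sign()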


Hi @ptrblck, could you help me confirm: if I use torch.no_grad() within model.train(), will the batch norm and dropout layers keep updating (without gradient calculation) when I pass data through the model, while the other layers stay fixed?

torch.no_grad() will not store the intermediate activations and thus your gradient calculation will fail. However, optimizer.step() could still update the parameters, even with a zero gradient, if it’s using running stats internally.

model.train() will allow updates to the running stats of batchnorm layers for each forward pass. Dropout layers will be used but are not trainable.
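A small self-contained sketch that illustrates this behavior: the running stats of a batchnorm layer change after a forward pass under torch.no_grad(), while no gradients are produced:

import torch
import torch.nn as nn

bn = nn.BatchNorm1d(4)
bn.train()                 # default mode, shown for clarity
print(bn.running_mean)     # starts at zeros
with torch.no_grad():
    bn(torch.randn(8, 4))  # forward pass only, no autograd graph is built
print(bn.running_mean)     # the running stats were updated anyway
print(bn.weight.grad)      # None, since no gradients were computed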


Thanks for the explanation!