'model.eval()' vs 'with torch.no_grad()'

@albanD @ptrblck Thank you for your clear explanations! I have learned a lot from this question.

I tried both and I noticed that with torch.no_grad() in my evaluation loop it took ~5 sec more per epoch. I don't get why.

By ā€œbothā€ do you mean model.eval() and torch.no_grad()?
If so, these calls are independent as explained in this previous post.
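
For reference, here is a minimal sketch of the usual pattern combining both (the model and loader below are toy placeholders, not taken from your code):

import torch
import torch.nn as nn

model = nn.Sequential(nn.Linear(10, 10), nn.BatchNorm1d(10), nn.Dropout(0.5))
val_loader = [(torch.randn(4, 10), torch.randint(0, 2, (4,))) for _ in range(3)]

model.eval()  # dropout/batchnorm switch to their evaluation behavior
with torch.no_grad():  # autograd stops recording, which saves memory and compute
    for data, target in val_loader:
        output = model(data)
        # ... compute validation metrics here ...
model.train()  # switch back before resuming training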

What is your baseline, that runs faster? Is it the training loop?
And how are you profiling the code? Note that CUDA operations are asynchronous, so that you would have to synchronize before starting and stopping the timer.
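
For example, a rough timing sketch with the required synchronization (assuming a CUDA device is available; the model and shapes are arbitrary):

import time
import torch

model = torch.nn.Linear(1024, 1024).cuda()
x = torch.randn(64, 1024, device='cuda')

torch.cuda.synchronize()  # wait for pending kernels before starting the timer
t0 = time.perf_counter()
with torch.no_grad():
    for _ in range(100):
        y = model(x)
torch.cuda.synchronize()  # wait for the queued kernels to finish before stopping
t1 = time.perf_counter()
print(f'elapsed: {t1 - t0:.4f}s')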

I'm only talking about my eval loop (validation loop), and I'm comparing the eval loop with and without torch.no_grad(). In particular, with torch.no_grad() it's a little bit slower (around 5 sec). I always use model.eval() before entering the loop. I have a class named RunManager(), which does exactly what the name implies; among many other things it keeps track of start/end times for each run (a different set of parameters, like batch size, learning rate, etc.) and for each epoch. All my code runs on CUDA.

That's weird.
Could you post a code snippet to reproduce this issue?

I had a mistake in my code: model.train() was on the wrong line. Indeed, for the same run:

  • with torch.no_grad(): ~55 sec

  • without torch.no_grad(): ~57 sec

If we want to select the best model by minimum validation loss, why do we need to set eval mode to compute the validation loss? Why not:

with torch.no_grad():
    model.train()

Why don't we want to consider dropout or BN when computing the validation loss?

If I only use torch.no_grad() without model.eval(), does it compromise the final performance (ignoring the speed for now)?

Hi Jangang,
the dropout layers and batch normalization layers behave differently in the train and eval (test) modes. Specifically, a dropout layer is stochastic with e.g. a dropout rate of p in train mode, and deterministic in eval mode. Thus, when doing the evaluation (dev test), {with torch.no_grad(): model.train()} is not equivalent to {with torch.no_grad(): model.eval()}.
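
A quick toy check of this (the tensor and dropout rate below are arbitrary):

import torch

m = torch.nn.Dropout(p=0.5)
a = torch.ones(1, 5)

with torch.no_grad():
    m.train()
    print(m(a))  # stochastic: some elements zeroed, survivors scaled by 1/(1-p)
    m.eval()
    print(m(a))  # deterministic: the input is returned unchanged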

Hi Yongkai,
yes it does, according to the discussion above. Without calling model.eval(), layers like dropout or batch normalization remain stochastic, and so does the final performance.

Thanks for the reply. Yes, I know that the behavior of dropout and BN should be different in train and test. But I mean: when we use a validation set during training to select the model with minimum validation loss, what should the behavior of the dropout or BN layers be? Should they be stochastic or deterministic?

tl;dr: They should be deterministic while evaluating

Let's investigate it with the help of the code:
The method Module.eval() contains only a single statement, which calls self.train(False) (see the torch.nn.modules.module source in the PyTorch documentation). This in turn sets the training flag of the current module and all its children to False. In dropout, for example, the state of this training flag is checked before calling the C++ code referenced in torch.nn.functional. And according to the Dropout documentation, the layer randomly zeroes elements during training only, and thus behaves deterministically after calling eval().

You can verify this with the following code snippet:

import torch
m = torch.nn.Dropout(0.1)
a = torch.randn(2, 3)
m(a)  # run this line several times and see the output change
m.eval()  # here the "training" flag is set to False
m(a)  # run this line several times as well; the output no longer changes
m.train()  # this call sets training = True and the Dropout layer behaves stochastically again
m(a)  # output with stochastic dropout

Hi, thanks for being so supportive in this community. I just have a small doubt.
Can we do validation like this?

with torch.no_grad():
    for batch in val_loader:
        # some code

i.e. no model.eval()

Thanks

You could do it

  • if the model contains no layers that would switch their behavior after calling model.eval() (e.g. dropout, batchnorm),
  • or if you explicitly want to use dropout during evaluation or to update the running stats of the batchnorm layers with the validation set (which would be considered a data leak in the "standard" use case); see the sketch below.
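
To make the batchnorm point concrete, a small sketch (toy tensors, not from the thread):

import torch

bn = torch.nn.BatchNorm1d(3)
x = torch.randn(8, 3) + 5.0  # stand-in for validation data

with torch.no_grad():
    print(bn.running_mean)  # all zeros initially
    bn(x)  # train mode: running stats absorb the validation statistics
    print(bn.running_mean)  # changed -> information leaked from the validation set
    bn.eval()
    bn(x)  # eval mode: running stats are used, not updated
    print(bn.running_mean)  # unchanged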

Thanks a lot. I was calculating the validation loss after each epoch using the following code and was using that value to update the scheduler. Would this be considered data leakage?

def Solver_NN(model, train_loader, dev_loader, optim, criterion, device, scheduler, print_every=10, epoch=51, lr=1e-1):
  print("Solver Initiated")
  model = model.to(device)  # send the model to the GPU
  print("Model successfully sent to the GPU\n")

  total_step = len(train_loader)

  for e in range(epoch):
    running_loss = 0.0
    epoch_loss = 0.0
    for i, (x, y) in enumerate(train_loader):
      optim.zero_grad()

      X = x.to(device=device, dtype=torch.float)
      y = y.to(device=device, dtype=torch.long)

      # forward pass
      y_pred = model(X)
      loss = criterion(y_pred, y)

      # backward pass
      loss.backward()
      optim.step()

      running_loss += loss.item()
      epoch_loss += loss.item()  # accumulate the per-batch loss over the whole epoch
      if (i + 1) % print_every == 0:  # print every `print_every` steps
        print("Epoch [{}/{}], Step [{}/{}] Loss: {}".format(e + 1, epoch, i + 1, total_step, running_loss / print_every))
        running_loss = 0.0

    with torch.no_grad():
      train_hter, train_loss = HTER(model=model, loss_criterion=criterion, loader=train_loader)
      dev_hter, dev_loss = HTER(model=model, loss_criterion=criterion, loader=dev_loader)
      print(f"Train loss in epoch {e+1} is {train_loss} and Train HTER in epoch {e+1}: {train_hter}")
      print(f"Dev loss in epoch {e+1} is {dev_loss} and Dev HTER in epoch {e+1}: {dev_hter}")
      scheduler.step(dev_loss)
    torch.save(model.state_dict(), sys_path + 'Codes/Replay Attack/weights_replay_attack_1FPS/weights1/Resnet_Replay_attack_No_LBP_' + str(e + 1) + "_" + str(np.floor(dev_hter)) + '.pkl')
    print("Model saved successfully!\n")

  return model

This is the HTER function used inside the with torch.no_grad() block:


def HTER(model, loss_criterion, loader):
  Ys = list()
  Y_preds = list()
  Loss_total = 0
  total_step = len(loader)

  for batch, label in loader:
    batch = batch.to(device=device, dtype=torch.float)  # `device` comes from the enclosing scope
    label = label.to(device=device, dtype=torch.long)
    Ys.append(label)
    Y_pred = model(batch)
    Y_preds.append(torch.argmax(Y_pred, dim=1))
    Loss = loss_criterion(Y_pred, label)
    Loss_total += Loss.item()

  Y = torch.cat(Ys, dim=0)
  Y_pred = torch.cat(Y_preds, dim=0)
  tp, tn, fp, fn = confusion_matrix(Y, Y_pred)  # custom helper returning the four counts
  hter = 1 - (0.5 * ((tp / (tp + fn)) + (tn / (tn + fp))))

  return (hter * 100, Loss_total / total_step)

If your model uses batchnorm (or any other normalization layer that updates running stats), I would consider this a data leak, yes, since some information from the validation dataset was used to update the model.
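
Applied to the code above, a minimal fix could look like this (a sketch only; it reuses HTER, criterion, the scheduler, and the loaders from the previous posts):

model.eval()  # freeze dropout/batchnorm so the running stats are not updated
with torch.no_grad():
    train_hter, train_loss = HTER(model=model, loss_criterion=criterion, loader=train_loader)
    dev_hter, dev_loss = HTER(model=model, loss_criterion=criterion, loader=dev_loader)
    scheduler.step(dev_loss)
model.train()  # restore training behavior for the next epoch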


Yes, my model has some BatchNorm2d layers. Thank you so much, this has been really insightful. So, if I want to update the scheduler with the val loss as described in the documentation of ReduceLROnPlateau, do I have to set the model to model.eval() first?
https://pytorch.org/docs/stable/optim.html#torch.optim.lr_scheduler.ReduceLROnPlateau

Hi, I am new to PyTorch.
Does it mean model.train() and model.eval() only affect layers like batchnorm and dropout? If my net doesn't have batchnorm or dropout layers, does it mean I don't need to call model.train() and eval()?

Even if you don't use these two layers, I would nevertheless recommend using model.train() and .eval() for the sake of clarity, and in case some new or custom layers switch their behavior; see the example below.
E.g. the model itself or any custom module you are using might check the self.training flag inside its forward method, which would thus switch the behavior.
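
For example, a custom module along these lines (a made-up illustration) would change its output between train() and eval() even though it contains neither dropout nor batchnorm:

import torch
import torch.nn as nn

class NoisyLinear(nn.Module):
    # adds input noise during training only, based on the self.training flag
    def __init__(self, in_features, out_features):
        super().__init__()
        self.fc = nn.Linear(in_features, out_features)

    def forward(self, x):
        if self.training:  # toggled by model.train() / model.eval()
            x = x + 0.1 * torch.randn_like(x)
        return self.fc(x)

m = NoisyLinear(4, 2)
x = torch.ones(1, 4)
m.train()
print(m(x))  # noisy, changes between calls
m.eval()
print(m(x))  # deterministic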


OK! Thanks for the explanation!