Confused about "set_grad_enabled"

  1. model.train() and model.eval() change the behavior of some layers, e.g. nn.Dropout won’t drop activations anymore and nn.BatchNorm layers will use the running estimates instead of the batch statistics. The torch.set_grad_enabled(False) line makes sure that the intermediate activations, which would be needed to backpropagate during training, are not stored during evaluation, thus saving memory. It’s comparable to the with torch.no_grad() statement, but takes a bool argument, so you can toggle it for training and evaluation (see the first sketch after this list).

  2. All tensors created by operations inside the torch.set_grad_enabled(False) block won’t require gradients. However, the model parameters themselves will still have requires_grad=True, as the first sketch below also shows.

  3. Since the criterion returns the mean loss over the batch by default, multiplying it with inputs.size(0) “de-averages” it back to a sum over the samples in that batch. Therefore you should divide running_loss by the length of the whole dataset, not by the number of batches (see the second sketch below).
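
Here is a minimal sketch of points 1 and 2, assuming a toy nn.Sequential model; the layer sizes and input are made up for illustration:

```python
import torch
import torch.nn as nn

# Toy model and input, made up for illustration.
model = nn.Sequential(nn.Linear(10, 10), nn.Dropout(p=0.5), nn.BatchNorm1d(10))
x = torch.randn(4, 10)

model.eval()  # nn.Dropout passes everything through, nn.BatchNorm1d uses running estimates
with torch.set_grad_enabled(False):  # acts like torch.no_grad() when passed False
    out = model(x)  # no intermediate activations are stored for backprop

print(out.requires_grad)              # False: created inside the disabled-grad block
print(model[0].weight.requires_grad)  # True: parameters still require gradients
```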
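
And a minimal sketch of the accumulation pattern from point 3, assuming a hypothetical model, criterion, and DataLoader (none of these names come from the original training loop):

```python
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset

# Hypothetical model, criterion, and data, just for illustration.
model = nn.Linear(10, 1)
criterion = nn.MSELoss()  # returns the mean loss over the batch by default
loader = DataLoader(TensorDataset(torch.randn(100, 10), torch.randn(100, 1)),
                    batch_size=16)

model.eval()
running_loss = 0.0
with torch.set_grad_enabled(False):
    for inputs, targets in loader:
        loss = criterion(model(inputs), targets)
        # "de-average": scale the batch-mean loss back to a per-sample sum;
        # this also handles a smaller final batch correctly
        running_loss += loss.item() * inputs.size(0)

# divide by the number of samples, not the number of batches
epoch_loss = running_loss / len(loader.dataset)
print(epoch_loss)
```

If you divided by len(loader) instead, batches of different sizes (e.g. a smaller last batch) would be weighted unequally in the epoch loss.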
