Why do we need torch.set_grad_enabled(False) here?

In the following tutorial, why do we need “with torch.set_grad_enabled(phase == 'train')”? We aren’t calling backward in the “test” phase anyway, so no gradients will be calculated.

https://pytorch.org/tutorials/beginner/transfer_learning_tutorial.html#training-the-model
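For reference, here is a minimal sketch of the pattern being asked about. It assumes a toy nn.Linear model, a single random batch, and phase names "train"/"test" as stand-ins for the tutorial's model and dataloaders; it is not copied from the tutorial. It records whether the loss ended up attached to an autograd graph in each phase.

```python
import torch
import torch.nn as nn

# Hypothetical stand-ins for the tutorial's model, data, and optimizer.
torch.manual_seed(0)
model = nn.Linear(4, 2)
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
inputs = torch.randn(8, 4)
labels = torch.randint(0, 2, (8,))

records = {}
for phase in ["train", "test"]:
    if phase == "train":
        model.train()
    else:
        model.eval()

    optimizer.zero_grad()
    # Autograd records a graph only when phase == "train".
    with torch.set_grad_enabled(phase == "train"):
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        if phase == "train":
            loss.backward()
            optimizer.step()

    # True if the loss is attached to a recorded graph in this phase.
    records[phase] = loss.grad_fn is not None
    print(phase, loss.item(), records[phase])
```

In the "train" phase the loss carries a grad_fn; in the "test" phase it does not, because nothing was recorded.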

You are right that we aren’t calling backward on “test”.
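However, during “test”, the forward pass still records a computation graph, and that graph (including the intermediate activations saved for backward) stays in memory as long as the output tensors that reference it are alive. A backward pass would free those saved buffers, but we never call it here. To avoid all of this, with torch.set_grad_enabled(False) turns off gradient tracking for everything inside the block, so no computational graph is built at all, since we do not intend to backpropagate through these computations.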
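A small check of what disabling grad mode actually does, assuming a toy nn.Linear layer and random input chosen only for illustration: outputs computed inside the block have no grad_fn, the previous grad mode is restored when the block exits, and calling backward on such an output fails because there is no graph to traverse.

```python
import torch
import torch.nn as nn

model = nn.Linear(3, 1)
x = torch.randn(5, 3)

# Default mode: autograd records the operations, so the output carries a graph.
y_train = model(x)

# Grad mode disabled: nothing is recorded for the operations in this block.
with torch.set_grad_enabled(False):
    grad_mode_inside = torch.is_grad_enabled()
    y_eval = model(x)

grad_mode_after = torch.is_grad_enabled()  # the previous mode is restored on exit

print(y_train.requires_grad, y_train.grad_fn is not None)  # True True
print(y_eval.requires_grad, y_eval.grad_fn)                 # False None
print(grad_mode_inside, grad_mode_after)                    # False True

try:
    y_eval.sum().backward()
    backward_failed = False
except RuntimeError:
    # There is no graph to backpropagate through.
    backward_failed = True
```

Note that grad mode is a global (per-thread) autograd setting rather than something each layer checks individually, which is why a single with block covers the whole model.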

Great, thank you for the response, that makes sense. So is cleaning up the computational graph required only because of memory? What would happen if we don’t clean it up? Would it affect the gradient calculations in the next iteration?

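I think the gradient calculations of the next iterations won’t be affected, because every forward call builds its own graph from its own batch of data. But if the computational graphs are never cleaned up (for example, because a tensor that references them, such as an accumulated loss, is kept alive), memory usage keeps growing and you would eventually run into an out-of-memory error.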
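A minimal sketch of how graphs typically end up being kept alive during evaluation, assuming a toy nn.Linear model and random batches as stand-ins. The helper count_graph_nodes is hypothetical, written only for this illustration; it walks grad_fn.next_functions to count the autograd nodes still reachable from a result. Accumulating the loss tensor itself chains every batch's graph together, while accumulating loss.item() keeps no reference to any graph.

```python
import torch
import torch.nn as nn


def count_graph_nodes(t):
    """Hypothetical helper: count autograd nodes reachable from t.grad_fn (0 if no graph)."""
    if not isinstance(t, torch.Tensor) or t.grad_fn is None:
        return 0
    seen, stack = set(), [t.grad_fn]
    while stack:
        fn = stack.pop()
        if fn is None or fn in seen:
            continue
        seen.add(fn)
        stack.extend(next_fn for next_fn, _ in fn.next_functions)
    return len(seen)


def evaluate(num_batches, keep_graph):
    torch.manual_seed(0)
    model = nn.Linear(4, 2)
    criterion = nn.CrossEntropyLoss()
    running_loss = 0.0
    for _ in range(num_batches):
        inputs = torch.randn(8, 4)
        labels = torch.randint(0, 2, (8,))
        loss = criterion(model(inputs), labels)  # grad mode is on, so a graph is built
        if keep_graph:
            running_loss = running_loss + loss  # tensor sum: links every batch's graph together
        else:
            running_loss += loss.item()          # Python float: the graph can be freed
    return running_loss


leaky_one = evaluate(1, keep_graph=True)
leaky_five = evaluate(5, keep_graph=True)
clean_five = evaluate(5, keep_graph=False)
print(count_graph_nodes(leaky_one), count_graph_nodes(leaky_five), count_graph_nodes(clean_five))
```

The graph reachable from the accumulated tensor grows with every batch, and each of those graphs holds its saved activations. Wrapping the loop in torch.set_grad_enabled(False) (or using loss.item()) prevents this growth.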
