torch.no_grad vs requires_grad = False

I understand that torch.no_grad saves memory and computation because operations run inside it are not recorded by autograd, so no gradients are backpropagated through them.
So, is the following statement correct? If I have a network and only want to update the first and last layers, I cannot simply put all the middle layers inside the torch.no_grad context manager. Instead, I should set requires_grad = False in all the middle layers. However, due to the chain rule, the gradients in the middle layers still need to be computed in order to compute the gradient of the first layer. Therefore, the conclusion is that requires_grad = False will not save me any memory or computation, unlike torch.no_grad.
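For concreteness, here is a toy sketch of the setup I mean (the model and layer sizes are just made up for illustration):

```python
import torch
import torch.nn as nn

# Toy model: only the first and last layers should be updated
model = nn.Sequential(
    nn.Linear(10, 20),  # first layer: keep trainable
    nn.Linear(20, 20),  # middle layer: freeze
    nn.Linear(20, 20),  # middle layer: freeze
    nn.Linear(20, 1),   # last layer: keep trainable
)

# Freeze all middle layers by turning off requires_grad on their parameters
for layer in list(model)[1:-1]:
    for p in layer.parameters():
        p.requires_grad = False
```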

Instead, I should set requires_grad = False in all the middle layers.

Yes, this is the right approach.

However, due to the chain rule, the gradients in the middle layers still need to be computed in order to compute the gradient of the first layer.

Correct, but not entirely. You will still need to compute gradients wrt the inputs of each middle layer so that gradients can flow back to the earlier layers, but you won't need to compute gradients wrt the parameters you've set requires_grad=False on, and that saves computation.
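A minimal sketch of that (the toy model below is made up): backward still propagates through the frozen middle layer so the first layer gets a gradient, but no gradient is computed for the frozen parameters themselves.

```python
import torch
import torch.nn as nn

model = nn.Sequential(
    nn.Linear(10, 20),  # first layer: trainable
    nn.Linear(20, 20),  # middle layer: frozen
    nn.Linear(20, 1),   # last layer: trainable
)
for p in model[1].parameters():
    p.requires_grad = False

loss = model(torch.randn(4, 10)).sum()
loss.backward()

print(model[0].weight.grad is not None)  # True: the gradient flowed through the frozen layer
print(model[1].weight.grad)              # None: no gradient is computed for the frozen parameters
print(model[2].weight.grad is not None)  # True
```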

You will also save memory as a result of this, because which gradients you need determines which tensors have to be saved for backward. For example, if I have z = x * y and I only care about dz/dy, I only need to save x (since dz/dy = x), not y.
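Here is a small sketch of that z = x * y example. It uses torch.autograd.graph.saved_tensors_hooks (available in recent PyTorch releases) to print which tensors get saved for backward; exactly what each op saves can vary between versions, but with only y requiring grad you should see just x being saved.

```python
import torch

def pack(t):
    print(f"saved a tensor of shape {tuple(t.shape)}")
    return t

def unpack(t):
    return t

x = torch.randn(3)                      # requires_grad=False
y = torch.randn(3, requires_grad=True)  # we only need dz/dy

with torch.autograd.graph.saved_tensors_hooks(pack, unpack):
    z = x * y  # dz/dy = x, so only x needs to be saved for backward

z.sum().backward()
print(torch.equal(y.grad, x))  # True: dz/dy = x
```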

As a small addition: the code below illustrates what @soulitzer described.
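A rough sketch along those lines (not an exact snippet; the toy model and the use of torch.autograd.graph.saved_tensors_hooks are just one way to illustrate it): it counts how many tensors autograd saves for backward with and without freezing the middle layer.

```python
import torch
import torch.nn as nn

def count_saved_tensors(freeze_middle):
    model = nn.Sequential(nn.Linear(10, 10), nn.Linear(10, 10), nn.Linear(10, 1))
    if freeze_middle:
        for p in model[1].parameters():
            p.requires_grad = False

    saved = []

    def pack(t):
        # Called whenever autograd saves a tensor for the backward pass
        saved.append(t.shape)
        return t

    with torch.autograd.graph.saved_tensors_hooks(pack, lambda t: t):
        model(torch.randn(4, 10)).sum().backward()
    return len(saved)

# Exact counts depend on the PyTorch version, but freezing the middle layer
# should result in fewer saved tensors, since its input no longer has to be
# kept around for computing its weight gradient.
print("middle layer frozen: ", count_saved_tensors(True))
print("all layers trainable:", count_saved_tensors(False))
```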
