In-place parameter update without torch.no_grad()

Hello All,

I have just started learning this awesome tool called PyTorch, but sadly I am stuck on a confusing point.

Below is a code snippet from one of the tutorials:

    with torch.no_grad():
        weights -= weights.grad * lr
        bias -= bias.grad * lr
        weights.grad.zero_()
        bias.grad.zero_()

I am kind of confused. Even if I do the parameter update without torch.no_grad() (i.e. only in place), the backward call has already been made in the code above this snippet (not included here), which means all the grad attributes have already been computed and the "original" values are not needed again. So why is it illegal to do those operations without torch.no_grad()?

I know PyTorch will raise an error, but I just want to know where my line of thought goes wrong.
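For reference, here is a minimal sketch of what I mean (the tensor names and shapes here are just placeholders, not from the tutorial): an in-place update on a leaf parameter outside of torch.no_grad() versus inside it.

    import torch

    # Hypothetical single trainable parameter (a leaf tensor with requires_grad=True).
    weights = torch.randn(3, requires_grad=True)
    lr = 0.1

    loss = (weights ** 2).sum()
    loss.backward()                    # weights.grad is now populated

    try:
        weights -= weights.grad * lr   # in-place update outside no_grad()
    except RuntimeError as e:
        # Raises something along the lines of:
        # "a leaf Variable that requires grad is being used in an in-place operation."
        print(e)

    with torch.no_grad():
        weights -= weights.grad * lr   # the same update succeeds here
        weights.grad.zero_()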


Hi,

It is illegal because it would mean that the weight update is done in a differentiable manner.
That means that in the next iteration, when computing the gradients, they would flow back through multiple iterations of your training loop, which is most likely not what you want!
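As a rough sketch of what that would look like (the names below are assumptions for illustration, not the tutorial's code): if the update is done in a differentiable way, the new weights keep a grad_fn, so the next backward() would also backpropagate through the update step itself.

    import torch

    lr = 0.1

    # Differentiable (out-of-place) update, NOT wrapped in no_grad():
    w = torch.randn(3, requires_grad=True)
    loss = (w ** 2).sum()
    loss.backward()
    w = w - lr * w.grad               # the update itself is recorded in the graph
    print(w.is_leaf, w.grad_fn)       # False  <SubBackward0 ...>

    # Update under no_grad(): w2 stays a plain leaf parameter with no history.
    w2 = torch.randn(3, requires_grad=True)
    loss2 = (w2 ** 2).sum()
    loss2.backward()
    with torch.no_grad():
        w2 -= lr * w2.grad
    print(w2.is_leaf, w2.grad_fn)     # True  None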

Hey,

Could you please clarify where I am wrong in the thought process below?

Assume I only have one trainable weight parameter to update. On calling loss.backward(), the gradient of the loss w.r.t. that weight is saved in the node corresponding to the tensor containing that weight.
The way I perceive an in-place operation is: "The new value of the weight replaces the previous value in the same node", i.e. as far as I understand (unsure if it's correct), nothing in the computational graph changes other than the value; no new node or edge is added.

Going by my perception, since nothing except the values has changed, it shouldn't be illegal.

Please correct the parts where I am wrong.


there is no change in the computational graph other than the “value” update

No, there is actually a change, because autograd needs to take this change into account to compute the proper gradients.

Also, this issue only shows up when you run multiple iterations. For a single one, it indeed doesn't change much, since you don't reuse the graph.
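To make that concrete, here is a small sketch (an assumed example, not the code from the tutorial) of autograd noticing an in-place change: every tensor carries a version counter, and backward() raises an error if a tensor that was saved for the backward pass has been modified in place since the forward pass.

    import torch

    x = torch.randn(3, requires_grad=True)
    y = x.sigmoid()      # sigmoid's backward needs its saved output y
    y.add_(1.0)          # in-place change bumps y's version counter

    try:
        y.sum().backward()
    except RuntimeError as e:
        # Something like: "... needed for gradient computation has been
        # modified by an inplace operation ..."
        print(e)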

Thank you so much for strengthening my understanding of PyTorch.