Basic Explanation of torch gradients - '.backward()' vs autograd

Hi,
I am trying to understand how torch works.
I have searched the docs, but although they are very descriptive, I couldn't find a concrete answer.
So I will ask this question differently: I will write down what I have understood, and I would be glad if someone could correct me.

My understanding of how PyTorch and autograd work is as follows:

  1. The computational graph is built from scratch in every .forward() pass.

  2. The .forward() pass adds the .grad attribute to all the parameters whose .requires_grad attribute is True (but does not assign any value to it).

  3. Then, when we call .backward(), we backpropagate through the graph and add the gradients to the .grad attributes of the parameters with .requires_grad=True.

  4. If we don't call .backward(), .grad will be None (even if .requires_grad is True).

  5. If we call .backward() several times in a row (even with no .forward() in between), the gradients accumulate in the .grad attribute of every parameter.

  6. If we do a .forward() pass under torch.no_grad(), we build the graph with the .grad of all the parameters set to None, so .backward() will have no effect.

  7. optimizer.zero_grad() sets the .grad attributes of all the parameters inside the optimizer to 0 (or of all the model's parameters, if the optimizer was given all of them).

  8. optimizer.step() changes the weights of the parameters based on their .grad attribute.

  9. I'm sure I have made a mistake in a few of these, maybe in all of them, but I also want to understand why autograd is called "auto"… Maybe it calculates the gradients automatically during the .forward() pass? If so, why do we need .backward()? We could just call optimizer.step() and update the weights based on the gradients that autograd has already calculated…
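Items 3 to 5 and 7 can be checked with a tiny experiment. The sketch below uses a made-up parameter `w` and input `x` for illustration. One detail the list glosses over: calling .backward() a second time on the same graph normally raises a RuntimeError, because the graph's saved intermediate values are freed after the first call; passing retain_graph=True keeps them, which is what makes the accumulation from item 5 observable here. Also, in recent PyTorch versions optimizer.zero_grad() defaults to set_to_none=True, so it resets .grad to None rather than filling it with zeros.

```python
import torch

# Illustrative "parameter" that requires gradients, and a plain input tensor.
w = torch.tensor([1.0, 2.0], requires_grad=True)
x = torch.tensor([3.0, 4.0])

loss = (w * x).sum()      # forward pass: records the graph
grad_before = w.grad      # None: nothing has been computed yet

# First backward: d(loss)/dw = x is written into w.grad.
# retain_graph=True keeps the saved tensors so we can run backward again.
loss.backward(retain_graph=True)
grad_after_one = w.grad.clone()

# Second backward on the same graph: the new gradient is ADDED to w.grad.
loss.backward()
grad_after_two = w.grad.clone()

# Reset before the next iteration, as optimizer.zero_grad() does with its
# default set_to_none=True in recent PyTorch versions.
w.grad = None

print(grad_before)      # None
print(grad_after_one)   # tensor([3., 4.])
print(grad_after_two)   # tensor([6., 8.])
```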

I hope someone will clarify some of this for me…
Thank you,


As far as I understand, torch.autograd.backward() calculates the sum of the gradients of the given tensors w.r.t. the graph leaves, and optimizer.step() calculates the new parameter (weight) values based on these gradients.
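As a rough check of this description, the sketch below uses a small nn.Linear model with random data (both illustrative) and plain SGD, whose update without momentum or weight decay is p <- p - lr * p.grad. It calls torch.autograd.backward(loss), which has the same effect as loss.backward() for a scalar loss, and then compares optimizer.step() against that update computed by hand.

```python
import torch

torch.manual_seed(0)
model = torch.nn.Linear(3, 1)        # illustrative model
x = torch.randn(4, 3)                # made-up inputs
target = torch.randn(4, 1)           # made-up targets
lr = 0.1
# Plain SGD (no momentum, no weight decay): p <- p - lr * p.grad
opt = torch.optim.SGD(model.parameters(), lr=lr)

loss = torch.nn.functional.mse_loss(model(x), target)

# Same effect as loss.backward(): fills .grad of the leaves (weight and bias).
torch.autograd.backward(loss)

# The plain SGD update computed by hand from the .grad fields.
expected = [(p - lr * p.grad).detach().clone() for p in model.parameters()]

opt.step()  # updates the parameters in place using their .grad

matches = [torch.allclose(p, e) for p, e in zip(model.parameters(), expected)]
print(matches)  # [True, True]
```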


Hi,

I think all of these are pretty much correct, yes.
Just keep in mind the distinction between autograd (which just runs the backward pass and computes gradients) and torch.nn (which is a set of utilities for working with neural nets).
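To illustrate the distinction, autograd on its own works on plain tensors with nothing from torch.nn involved. This is a minimal sketch with an arbitrary function y = x**2 + 2*x chosen for illustration; torch.autograd.grad returns the gradient directly instead of accumulating it into .grad.

```python
import torch

# Plain tensors only: no nn.Module, no optimizer.
x = torch.tensor(3.0, requires_grad=True)
y = x ** 2 + 2 * x        # dy/dx = 2x + 2

# torch.autograd.grad returns a tuple with one gradient per input.
(dy_dx,) = torch.autograd.grad(y, x)
print(dy_dx)              # tensor(8.)
print(x.grad)             # None: autograd.grad does not write into .grad
```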

  1. We actually build the graph under the hood. The only place where you can see it is the .grad_fn field. The .grad attribute exists on every Tensor; it is just None by default.

  2. We accumulate into the .grad fields of all the leaf Tensors (t.is_leaf) that require gradients, yes. (Or of the ones on which you called retain_grad(), which can be used to get the .grad field populated for non-leaf Tensors.)

  3. In no_grad mode, no graph is actually built: new Tensors will have requires_grad=False and grad_fn=None.

  4. Here we enter the torch.nn realm (the optimizers themselves live in torch.optim). Your description is correct.

  5. Autograd is called "auto" because it does Automatic Differentiation, as opposed to Symbolic Differentiation or Numerical Differentiation, which are other ways to compute gradients.
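A small sketch of that difference, using an arbitrary function f(x) = sin(x) * x**2 chosen for illustration: autograd gives the derivative exactly (up to floating-point rounding) once .backward() is called, a central finite difference only approximates it, and the symbolic derivative is written out by hand for comparison. It also bears on question 9: the forward pass records the graph but computes no gradient until .backward() runs.

```python
import torch

def f(x):
    # Arbitrary example function, chosen only for illustration.
    return torch.sin(x) * x ** 2

x0 = torch.tensor(1.5, dtype=torch.float64)

# Automatic differentiation: the forward pass only records the operations.
x = x0.clone().requires_grad_(True)
y = f(x)
grad_after_forward = x.grad   # None: no gradient has been computed yet
y.backward()                  # the gradient is computed here
auto_grad = x.grad.item()

# Numerical differentiation: a central finite difference (an approximation).
h = 1e-6
num_grad = ((f(x0 + h) - f(x0 - h)) / (2 * h)).item()

# Symbolic derivative, worked out by hand: cos(x) * x^2 + 2 * x * sin(x).
exact = (torch.cos(x0) * x0 ** 2 + 2 * x0 * torch.sin(x0)).item()

print(grad_after_forward)          # None
print(auto_grad, num_grad, exact)  # all close to the same value
```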


Thank you for your answer,

  1. We actually build the graph under the hood. The only place where you can see it is the .grad_fn field. The .grad attribute exists on every Tensor; it is just None by default.
  1. We accumulate into the .grad fields of all the leaf Tensors (t.is_leaf) that require gradients, yes. (Or of the ones on which you called retain_grad(), which can be used to get the .grad field populated for non-leaf Tensors.)

So as I understand it, in every .forward() pass we build the graph under the hood, with .grad_fn connecting every variable to the function and the variables that created it (this is done by autograd).
When we call .backward(), we compute the .grad attribute of every leaf variable by following the .grad_fn attributes.
But how do we calculate the gradients for non-leaf variables? If we don't, how are the weights for those variables updated by optimizer.step()?
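The .grad_fn chain described in this reply can be inspected directly. In the sketch below (the tensors w and x are illustrative), each non-leaf tensor's grad_fn is the backward node of the operation that created it, next_functions links to the nodes of that operation's inputs, and a leaf that requires gradients ends in an AccumulateGrad node, which is where .backward() writes into .grad. An input that does not require gradients shows up as None.

```python
import torch

w = torch.tensor([1.0, 2.0], requires_grad=True)  # leaf that requires grad
x = torch.tensor([3.0, 4.0])                      # input, no grad needed

loss = (w * x).sum()

sum_node = loss.grad_fn                   # node for .sum()
mul_node = sum_node.next_functions[0][0]  # node for w * x
w_node = mul_node.next_functions[0][0]    # node for the input w
x_node = mul_node.next_functions[1][0]    # entry for the input x

print(w.grad_fn)                  # None: leaves are not produced by an operation
print(type(sum_node).__name__)    # SumBackward0
print(type(mul_node).__name__)    # MulBackward0
print(type(w_node).__name__)      # AccumulateGrad: writes into w.grad
print(x_node)                     # None: x does not require grad

loss.backward()
print(w.grad)                     # tensor([3., 4.])
```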

  1. In no_grad mode, no graph is actually built: new Tensors will have requires_grad=False and grad_fn=None.

So, in no_grad mode, will .backward() and optimizer.step() have no effect?

Thank you,

But how do we calculate the gradients for non-leaf variables?

You don't need them, because all nn.Parameter objects are leaves. So .backward() computes their gradients (for every parameter that was used in the forward pass), and the optimizer has a .grad to work with.
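A sketch of that point, using an arbitrary small nn.Sequential model and random input (both illustrative): every nn.Parameter is a leaf, so .backward() fills its .grad, while an intermediate activation is a non-leaf whose .grad stays None unless retain_grad() was called on it before the backward pass.

```python
import warnings
import torch

torch.manual_seed(0)
model = torch.nn.Sequential(
    torch.nn.Linear(3, 4), torch.nn.ReLU(), torch.nn.Linear(4, 1)
)  # illustrative model
x = torch.randn(2, 3)

pre = model[0](x)        # non-leaf activation, its .grad will not be kept
hidden = model[1](pre)   # non-leaf activation
hidden.retain_grad()     # opt in: also populate hidden.grad during backward
loss = model[2](hidden).sum()
loss.backward()

params_are_leaves = all(p.is_leaf for p in model.parameters())
params_have_grad = all(p.grad is not None for p in model.parameters())

with warnings.catch_warnings():
    warnings.simplefilter("ignore")  # reading .grad of a non-leaf emits a warning
    pre_grad = pre.grad

print(params_are_leaves, params_have_grad)      # True True
print(hidden.is_leaf, hidden.grad is not None)  # False True
print(pre_grad)                                 # None
```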

So, in no_grad mode, will .backward() and optimizer.step() have no effect?

Since the Tensor has requires_grad=False, you cannot call .backward() on it (you will get an error).
You can still call optimizer.step() (remember, this is an nn construct, unrelated to autograd), and it will perform its update based on the current .grad values (these can be None if .backward() was never called, in which case the built-in optimizers skip those parameters).
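A sketch of both cases with an illustrative nn.Linear model and plain SGD: under torch.no_grad() the output has requires_grad=False and no grad_fn, so .backward() raises a RuntimeError, and optimizer.step() then leaves the parameters unchanged because every .grad is still None.

```python
import torch

torch.manual_seed(0)
model = torch.nn.Linear(3, 1)                      # illustrative model
opt = torch.optim.SGD(model.parameters(), lr=0.1)
x = torch.randn(2, 3)

with torch.no_grad():
    loss = model(x).sum()       # no graph is recorded here

print(loss.requires_grad, loss.grad_fn)  # False None

try:
    loss.backward()
    backward_failed = False
except RuntimeError as err:     # loss does not require grad and has no grad_fn
    backward_failed = True
    print("backward() failed:", err)

# No backward pass has run, so every p.grad is None; step() skips such parameters.
before = [p.detach().clone() for p in model.parameters()]
opt.step()
unchanged = all(torch.equal(b, p) for b, p in zip(before, model.parameters()))
print(unchanged)  # True
```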

Side note: for some optimizers (such as Adam, or SGD with momentum or weight decay), the optimizer step will still change the weights even when the gradient is 0.
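A sketch of the side note with torch.optim.SGD and made-up values: with weight_decay the effective gradient becomes grad + weight_decay * p, so a zero gradient still shrinks p; with momentum, a zero gradient still moves p as long as the momentum buffer holds a non-zero value from earlier steps. The helper step_with_zero_grad is hypothetical, written only for this illustration.

```python
import torch

def step_with_zero_grad(**sgd_kwargs):
    """Hypothetical helper: one SGD step on a parameter whose gradient is exactly 0."""
    p = torch.nn.Parameter(torch.tensor([1.0, -2.0]))
    opt = torch.optim.SGD([p], lr=0.1, **sgd_kwargs)
    p.grad = torch.zeros_like(p)
    opt.step()
    return p.detach().clone()

plain = step_with_zero_grad()                    # p - 0.1 * 0: unchanged, [1.0, -2.0]
decayed = step_with_zero_grad(weight_decay=0.5)  # grad becomes 0 + 0.5 * p: [0.95, -1.9]

# Momentum: one step with a non-zero gradient, then one with a zero gradient.
q = torch.nn.Parameter(torch.tensor([1.0]))
opt = torch.optim.SGD([q], lr=0.1, momentum=0.9)
q.grad = torch.ones_like(q)
opt.step()                    # buffer = 1.0, q = 1.0 - 0.1 * 1.0 = 0.9
q.grad = torch.zeros_like(q)
opt.step()                    # buffer = 0.9 * 1.0 + 0 = 0.9, q = 0.9 - 0.09 = 0.81
after_zero_grad_step = q.item()

print(plain, decayed, after_zero_grad_step)
```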
