What does the backward() function do?

(Soumyadeep Ghosh) #1

I have two networks, “net1” and “net2”.
Let us say “loss1” and “loss2” are the classifier losses of “net1” and “net2”,
and “optimizer1” and “optimizer2” are the optimizers of the two networks.

“net2” is a pretrained network, and I want to backprop the (gradients of the) loss of “net2” into “net1”.
loss1 = … (some loss defined)
So, loss1 = loss1 + loss2 (let's say that loss2 was already defined)

So I do
loss1.backward(retain_graph=True)  # what happens when I write this?

What is the difference between backward() and step()?

If I do not write loss1.backward(), what will happen?

(colesbury) #2

loss.backward() computes dloss/dx for every parameter x which has requires_grad=True. These are accumulated into x.grad for every parameter x. In pseudo-code:

x.grad += dloss/dx
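The accumulation rule above can be checked on a toy case. This is only a minimal sketch, assuming a single scalar tensor `x` stands in for a model parameter; the names are illustrative and not from the thread.

```python
import torch

# A single scalar "parameter"; requires_grad=True makes autograd track it.
x = torch.tensor(3.0, requires_grad=True)

# First pass: loss = x^2, so dloss/dx = 2 * x = 6.
loss = x ** 2
loss.backward()
grad_after_first = x.grad.item()    # 6.0

# Second pass without clearing x.grad: another 6 is added on top.
loss = x ** 2
loss.backward()
grad_after_second = x.grad.item()   # 12.0

print(grad_after_first, grad_after_second)
```

The second call does not overwrite `x.grad`; it adds to it, which is why the gradient doubles.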

optimizer.step() updates the value of x using the gradient x.grad. For example, the SGD optimizer performs:

x += -lr * x.grad
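To see that the update rule above matches what plain SGD does, here is a small sketch. It assumes `torch.optim.SGD` with default settings (no momentum, no weight decay); `x` and `y` are hypothetical scalar parameters used only for the comparison.

```python
import torch

lr = 0.1

# Parameter updated through the optimizer.
x = torch.tensor(3.0, requires_grad=True)
optimizer = torch.optim.SGD([x], lr=lr)
(x ** 2).backward()      # x.grad = 2 * 3 = 6
optimizer.step()         # x becomes 3 - 0.1 * 6, about 2.4

# Same update written by hand, outside of autograd tracking.
y = torch.tensor(3.0, requires_grad=True)
(y ** 2).backward()      # y.grad = 6
with torch.no_grad():
    y += -lr * y.grad

print(x.item(), y.item())
```

Both values should agree up to floating-point rounding. Optimizers with momentum or adaptive learning rates use x.grad differently, so this equivalence is specific to plain SGD.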

optimizer.zero_grad() clears x.grad for every parameter x in the optimizer. It’s important to call this before loss.backward(), otherwise you’ll accumulate the gradients from multiple passes.
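Putting the three calls together, a typical training loop might look like the sketch below. The toy model, random data, and learning rate are assumptions made up for illustration; only the ordering of `zero_grad()`, `backward()` and `step()` is the point.

```python
import torch
from torch import nn

torch.manual_seed(0)

# Hypothetical toy setup: a linear model fitted to random data.
model = nn.Linear(2, 1)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
loss_fn = nn.MSELoss()
inputs = torch.randn(8, 2)
targets = torch.randn(8, 1)

losses = []
for _ in range(3):
    optimizer.zero_grad()                    # clear gradients left over from the previous pass
    loss = loss_fn(model(inputs), targets)   # forward pass
    loss.backward()                          # accumulate dloss/dx into each parameter's .grad
    optimizer.step()                         # update parameters using .grad
    losses.append(loss.item())

print(losses)
```

Depending on the PyTorch version, `zero_grad()` either fills the gradients with zeros or sets them to `None`; both prevent accumulation across passes.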

If you have multiple losses (loss1, loss2) you can sum them and then call backward() once:

loss3 = loss1 + loss2
loss3.backward()
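As a quick sanity check of summing losses, here is a sketch with one scalar parameter `w` (a hypothetical stand-in for real network weights) and two made-up losses. The gradient of the sum is the sum of the individual gradients.

```python
import torch

w = torch.tensor(2.0, requires_grad=True)

loss1 = w ** 2           # dloss1/dw = 2 * w = 4
loss2 = 3 * w            # dloss2/dw = 3
loss3 = loss1 + loss2

loss3.backward()         # a single backward pass covers both losses
print(w.grad.item())     # 4 + 3 = 7.0
```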

(Soumyadeep Ghosh) #3

Hi @colesbury, thanks for your illustration.

I have one more question. Let's say I want to backprop “loss3” into “net1” and do not want to backprop “loss2” into “net2”. In that case I should not write

loss2.backward()


I should only write

loss3 = loss1 + loss2
loss3.backward()

Right?

If I have called loss2.backward() but have not called optimizer2.step(), will that affect my gradients when I then call loss3.backward()?
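Going by the accumulation rule from #2, an extra loss2.backward() would leave gradients sitting in .grad whether or not step() is called, and a later loss3.backward() would add to them. The sketch below tries this with two scalar tensors, `w1` and `w2`, as hypothetical stand-ins for the parameters of net1 and net2; the losses are made up for illustration.

```python
import torch

w1 = torch.tensor(2.0, requires_grad=True)   # stands in for net1's parameters
w2 = torch.tensor(5.0, requires_grad=True)   # stands in for net2's parameters

out1 = w1 * 1.0          # "net1" forward pass
loss1 = out1 ** 2        # dloss1/dw1 = 2 * w1 = 4
loss2 = w2 * out1        # "net2" applied to net1's output:
                         # dloss2/dw1 = w2 = 5, dloss2/dw2 = out1 = 2
loss3 = loss1 + loss2

# Extra backward on loss2; retain_graph=True keeps the graph for the next call.
loss2.backward(retain_graph=True)   # w1.grad = 5, w2.grad = 2

# No zero_grad() in between, so loss3's gradients are added on top.
loss3.backward()                    # adds 4 + 5 to w1.grad and 2 to w2.grad

print(w1.grad.item(), w2.grad.item())   # 14.0 4.0 (instead of 9.0 and 2.0)
```

In this toy case the contribution of loss2 ends up counted twice. Calling zero_grad() before loss3.backward(), or simply dropping the separate loss2.backward(), would give 9.0 and 2.0.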

(Sarthak Bhagat) #4

Does backward() update the weights if we do not use an optimizer?

(Simon Wang) #5

No, it just computes the gradients.
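A small check of this, assuming a scalar `w` as a hypothetical parameter and plain `torch.optim.SGD`: the value stays the same after backward() and only changes once step() is called.

```python
import torch

w = torch.tensor(1.0, requires_grad=True)
optimizer = torch.optim.SGD([w], lr=0.5)

value_before = w.item()            # 1.0
(w * 4).backward()                 # only fills w.grad with 4
value_after_backward = w.item()    # still 1.0

optimizer.step()                   # w = 1 - 0.5 * 4 = -1
value_after_step = w.item()

print(value_before, value_after_backward, value_after_step)
```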

(John1231983) #6

Very clear explanation! Suppose I have two losses: loss1 gets its data from dataloader1 with a batch size of 4, while loss2 gets its data from loader2, also with a batch size of 4. What is the batch size for loss = loss1 + loss2? Is it 4 or 8? The code looks like this:


(Aerin Kim) #7

No, it does not. The update happens only when you call step().