Why do we need to set the gradients manually to zero in pytorch?


(jdhao) #21

In my use case, I am doing image retrieval using siamese network with 2 branches, so a dataset sample contains two images and a label indicating whether they are similar or not.

I do not want to change the image aspect ratio, so randomly cropping the images to the same size is not an option. As a result, the batch size is effectively 1. Each time we process one image pair, we accumulate its loss, and when the number of processed pairs reaches the real batch size, we backpropagate the accumulated loss.

In case 2, each time a single loss is calculated, it is divided by the real batch size and immediately backpropagated, after which its graph is freed, which is more memory efficient. I think case 2 and case 3 should give the same result. But because case 2 calls backward many more times, training is much slower (I confirmed this with some tests).

I prefer case 3 because it trains faster, but the real batch size has to be chosen carefully so that memory does not blow up.
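A minimal sketch of why case 2 and case 3 should give the same gradients, assuming a toy `nn.Linear` model, an MSE loss and random data as hypothetical stand-ins for the Siamese network and the image pairs. Since gradients add up in `.grad` across `backward()` calls, backpropagating each loss divided by the real batch size gives the same result as backpropagating their sum once; only the memory/speed trade-off differs.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
N = 4                          # hypothetical "real" batch size
model = nn.Linear(3, 1)        # hypothetical stand-in for the Siamese network
loss_fn = nn.MSELoss()
samples = [(torch.randn(1, 3), torch.randn(1, 1)) for _ in range(N)]

def grads(m):
    # copy the current .grad of every parameter
    return [p.grad.detach().clone() for p in m.parameters()]

# Case 2: backward after every sample, each loss scaled by 1/N
model.zero_grad()
for inp, target in samples:
    loss = loss_fn(model(inp), target) / N
    loss.backward()            # this sample's graph is freed right here
case2 = grads(model)

# Case 3: sum the scaled losses, then a single backward at the end
model.zero_grad()
total = 0.0
for inp, target in samples:
    total = total + loss_fn(model(inp), target) / N  # all N graphs stay alive
total.backward()
case3 = grads(model)

print(all(torch.allclose(a, b) for a, b in zip(case2, case3)))
```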


(jdhao) #22

Follow-up: I first tried accumulating 64 single losses and then calling backward once, but that failed (GPU out of memory). When I reduced the number of accumulated losses to 16, it worked. So right now the real batch size is 64, but I call backward every 16 samples (4 backward calls per batch).
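The 64/16 split can be sketched like this, again with a hypothetical toy model and random data. The helper `accumulate_grads` (a hypothetical name) sums the scaled losses of one chunk, calls `backward()` once per chunk, and lets `.grad` accumulate across chunks, so `optimizer.step()` is still called once per real batch of 64. Calling it with `chunk=64` reduces to a single backward, which gives a way to check that the chunked gradients match.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
REAL_BATCH = 64   # samples per optimizer step
CHUNK = 16        # samples whose graphs are kept alive before each backward()

model = nn.Linear(3, 1)                      # hypothetical stand-in model
loss_fn = nn.MSELoss()
opt = torch.optim.SGD(model.parameters(), lr=0.1)
data = [(torch.randn(1, 3), torch.randn(1, 1)) for _ in range(REAL_BATCH)]

def accumulate_grads(model, data, chunk):
    """Call backward once per `chunk` samples; .grad sums over all chunks."""
    model.zero_grad()
    chunk_loss = 0.0
    for i, (inp, target) in enumerate(data, start=1):
        # scale by the real batch size so the total matches one big backward
        chunk_loss = chunk_loss + loss_fn(model(inp), target) / len(data)
        if i % chunk == 0 or i == len(data):
            chunk_loss.backward()   # frees the graphs of this chunk only
            chunk_loss = 0.0
    return [p.grad.detach().clone() for p in model.parameters()]

grads_64 = accumulate_grads(model, data, REAL_BATCH)  # 1 backward, more memory
grads_16 = accumulate_grads(model, data, CHUNK)       # 4 backwards, less memory
opt.step()                                            # one update per real batch
```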


#23

Thanks a lot… I can understand it clearly now


(Hima) #24

Can you explain why #3 uses more memory than #2?
Why does calling loss.backward less often cause it to use more memory?


(Alban D) #25

#3 uses more memory because you need to store the intermediate results of all 10 forward passes to be able to backpropagate. In #2 you never hold more than the intermediate results of 1 forward pass.


(Hima) #26

That makes sense.

Also, you wrote

```python
# current graph is appended to existing graph
loss = loss + current_loss
```

I thought the loss would just be a scalar? But is it actually the entire graph?


(Alban D) #27

loss here is a Variable containing a single element, and associated with it is the full history of the computations that were made, which is what makes it possible to backpropagate.
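To see that `loss` carries its history, here is a small sketch (the toy model and the helper `count_nodes` are hypothetical) that repeats `loss = loss + current_loss` and walks the graph through `grad_fn.next_functions`. It uses plain tensors, since `Variable` was merged into `Tensor` in PyTorch 0.4. The value stays 0-dimensional, but the number of autograd nodes reachable from it grows every iteration, which is also why #3 needs more memory than #2.

```python
import torch
import torch.nn as nn

def count_nodes(t):
    """Count autograd nodes reachable from t.grad_fn via next_functions.
    A node reachable through several paths is counted once per path."""
    count, stack = 0, [t.grad_fn]
    while stack:
        fn = stack.pop()
        if fn is None:          # inputs that do not require grad show up as None
            continue
        count += 1
        stack.extend(next_fn for next_fn, _ in fn.next_functions)
    return count

torch.manual_seed(0)
model = nn.Linear(3, 1)         # hypothetical stand-in model
loss = torch.zeros(())          # 0-dim tensor with no history yet
sizes = []
for step in range(10):
    current_loss = model(torch.randn(1, 3)).pow(2).mean()
    # the new Add node points at the old loss AND at current_loss's graph
    loss = loss + current_loss
    sizes.append(count_nodes(loss))

print(loss.shape)   # torch.Size([]): a single number...
print(sizes)        # ...but the graph behind it keeps growing
```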


#28

Where is this history stored exactly? It seems to be stored outside the variable. Let's say I create two losses like so:

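```python
import torch
import torch.nn as nn
from torch.autograd import Variable  # since PyTorch 0.4 a plain tensor behaves the same

B = 8
linear = nn.Linear(5, 1)
x = Variable(torch.ones(B, 5))
y = linear(x)
loss_1 = 10 - y.sum()
loss_2 = 5 - y.sum()
```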

As soon as I backpropagate loss_1, the graph's buffers are freed:

```python
loss_1.backward()
```

Backpropagating loss_2 now raises an error:

```python
loss_2.backward()  # raises RuntimeError
```

RuntimeError: Trying to backward through the graph a second time, but the buffers have already been freed. Specify retain_graph=True when calling backward the first time.

So it seems the history is stored outside both loss_1 and loss_2, and that it is not freed after calling backward() if retain_graph=True.

EDIT: Is it correct to assume that a new graph is created at the step y = linear(x)? If so, can we presume that those buffers (the history) reside in y and are referred to by subsequent Variables like loss_1 and loss_2?
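As far as I understand autograd, the history is not stored inside loss_1 or loss_2 as data. Each tensor keeps a reference to its `grad_fn` node, each node points to the nodes that produced its inputs through `next_functions`, and the buffers needed for backward (for example the input `x` that the linear layer needs for the weight gradient) are saved inside those nodes. `y = linear(x)` creates the nodes for the linear layer, and both losses reach them through `y`, which is why freeing them during `loss_1.backward()` breaks `loss_2.backward()`. A sketch reproducing both behaviours with plain tensors (PyTorch >= 0.4, where `Variable` is no longer needed):

```python
import torch
import torch.nn as nn

B = 8
linear = nn.Linear(5, 1)
x = torch.ones(B, 5)
y = linear(x)               # y.grad_fn holds what linear() saved for backward
loss_1 = 10 - y.sum()
loss_2 = 5 - y.sum()        # shares y's nodes with loss_1

# Without retain_graph, the second backward fails
loss_1.backward()           # frees the saved buffers in the shared nodes
try:
    loss_2.backward()
    second_backward_failed = False
except RuntimeError:
    second_backward_failed = True

# Rebuild the graph and keep the buffers alive for a second pass
linear.zero_grad()
y = linear(x)
loss_1 = 10 - y.sum()
loss_2 = 5 - y.sum()
loss_1.backward(retain_graph=True)  # buffers in the shared nodes are kept
loss_2.backward()                   # succeeds; buffers are freed after this call
print(linear.weight.grad)           # gradients of both losses are summed
```

Each loss contributes -8 (minus the batch size) to every weight and bias gradient, so after both backward calls every entry is -16.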