Why do we need to set the gradients manually to zero in pytorch?

(It seems like we could have an option in backward() to not zero out the gradients, e.g. backward(preserve_grads=True), but zeroing out the gradients could be the default action.)

9 Likes

Hi,

I think the big difference with tensorflow is the following.
Since you use a static graph, you define exactly what should be done to make one gradient computation/update, and then you just tell it to do it using a given input/target.

In pytorch, things are significantly more flexible: the autograd engine will just “remember” how to compute the gradient of a given variable while you are performing computations with this Variable. This means that you can get the gradients wrt a variable, then perform more computations with it, and then compute the gradients corresponding to these new operations.
In this scheme, there is not a single point where you stop performing “forward” operations and know that the only thing left to do is compute the gradients. So it is trickier to automatically set the gradients to 0, because you don’t know when one computation ends and a new one starts.
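For instance, here is a minimal sketch of that flexibility (the values are arbitrary and just for illustration):

import torch
from torch.autograd import Variable

x = Variable(torch.ones(1), requires_grad=True)
y = x * 2
y.backward(retain_graph=True)   # dy/dx = 2, so x.grad is now 2
z = y * 3                       # keep computing with the same variables
z.backward()                    # dz/dx = 6 is added on top: x.grad is now 2 + 6 = 8
print(x.grad)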

An example where gradient accumulation is useful is when you share some part of a network between two different tasks:

input = Variable(data)
# Get the features
features = feature_extractor(input)

# Compute first loss and get the gradients for it
loss1 = task1(features)
loss1.backward(retain_graph=True)
# This adds the gradients wrt loss1 in both the "task1" net and the "feature_extractor" net
# So each parameter "w" in "feature_extractor" has its gradient d(loss1)/dw

# Perform the second task and get the gradients for it as well
loss2 = task2(features)
loss2.backward()
# This will add gradients in "task2" and accumulate in "feature_extractor"
# Now each parameter in "feature_extractor" contains d(loss1)/dw + d(loss2)/dw

So the fact that the gradients are accumulated allows you to get the correct gradient for all the computations that you do with a given Variable, even if you use it in multiple places in convoluted ways.
The drawback here is that you have to manually reset the values to 0 so that the gradients computed previously do not interfere with the ones you are currently computing.

82 Likes

That’s interesting. I would have thought that the “sharing” would have been multiplicative, since RNNs are a composition of functions and not a mere addition (as in your transfer learning example). I think I am more confused. :frowning:

The addition comes from the rules of differentiation: if f = f1 + f2, then the gradient of f with respect to a parameter that appears in both branches is the sum of the contributions of each branch.
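A quick sketch of that sum rule with a shared parameter (toy values, just for illustration):

import torch
from torch.autograd import Variable

w = Variable(torch.Tensor([2]), requires_grad=True)
f1 = 3 * w      # d(f1)/dw = 3
f2 = w * w      # d(f2)/dw = 2*w = 4
f = f1 + f2
f.backward()
print(w.grad)   # 3 + 4 = 7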

1 Like

y.backward() doesn’t just assign the value of y’(x) to x.grad (say y depends on x). It actually adds y’(x) to the current value of x.grad (think of it as x.grad += true_gradient).

In the following example, y.backward() is called 5 times, so the final value of x.grad will be 5*cos(0)=5.

import torch
from torch.autograd import Variable

x = Variable(torch.Tensor([[0]]), requires_grad=True)

for t in range(5):
    y = x.sin() 
    y.backward()
    
print(x.grad) # shows 5

Calling x.grad.data.zero_() before y.backward() makes sure that x.grad is exactly the current y’(x), not a sum of the y’(x) from all previous iterations.

x = Variable(torch.Tensor([[0]]), requires_grad=True) 

for t in range(5):
    if x.grad is not None:
        x.grad.data.zero_()
    y = x.sin() 
    y.backward()

print(x.grad) # shows 1

I also got confused by this “zeroing gradient” when first learning pytorch. The doc of torch.autograd.backward does mention that

This function accumulates gradients in the leaves - you might need to zero them before calling it.

But this is quite hard to find and pretty confusing for (say) tensorflow users.

Official tutorials like 60 Minute Blitz or PyTorch with Examples both say nothing about why one needs to call grad.data.zero_() during training. I think it would be useful to explain this a little more in beginner-level tutorials. RNNs are a good example of why accumulating gradients (instead of refreshing them) is useful, but I guess new users wouldn’t even know that backward() accumulates gradients :sweat_smile:
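For reference, here is a minimal sketch of the usual training-loop pattern (model, optimizer, criterion and loader are placeholder names):

for input, target in loader:
    optimizer.zero_grad()             # clear the gradients accumulated in the previous iteration
    output = model(input)
    loss = criterion(output, target)
    loss.backward()                   # accumulates d(loss)/dw into each parameter's .grad
    optimizer.step()                  # updates the parameters using the freshly computed .grad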

49 Likes

@tom, since it is possible to accumulate the loss of several minibatches and do one parameter update: for example, if I want to update the parameters every 64 minibatches, I have the following code

total_loss = Variable(torch.zeros(1), requires_grad=True)

for idx, (data, target) in enumerate(train_loader):

    data, target = Variable(data), Variable(target)
    output = model(data)
    loss = criterion(output, target)
    total_loss = total_loss + loss

    if (idx+1)%64 == 0:
        total_loss = total_loss/(64*batchsize)
        total_loss.backward()
        optimizer.step()
        optimizer.zero_grad()
        total_loss = Variable(torch.zeros(1), requires_grad=True)

Is the above code correct to achieve the desired effect?

3 Likes

I’d do the backward on the (reweighted) loss in each iteration and not accumulate a total loss.

Best regards

Thomas

1 Like

I find that doing the backward for each reweighted loss is much slower than accumulating the loss and then doing a single backward.

2 Likes

Thanks for sharing this insight. I would have thought that you get similar speed at much less memory with separate backward calls. But clearly, the experiment proves my modest intuition to be wrong.

Best regards

Thomas

2 Likes

@albanD, I have some doubts about what the computation graph looks like in the case where we accumulate the loss manually.

Assume there are 256 batches in total for the dataset. In the normal case (case 1), for each batch in an epoch, a new computation graph is created and freed after the backward pass. So 256 computation graphs are created and freed during one epoch.

In this case (case 2), since we only do the backward every 64 batches, does that mean only 4 graphs are created? Each graph would be created by composing 64 of the smaller graphs from case 1, with total_loss as the root node of the bigger graph. The 64 smaller graphs all share the same set of learnable parameters. If that is the case, the bigger graph will consume a lot of memory since it contains 64 copies of the small graph.

Is that right? Do you have any ideas?

1 Like

Hi,

Indeed, in the first case, you will create 256 graphs that each work with one input.
In the second case, you will create only 4 graphs, but each of these 4 graphs is actually composed of 64 copies of the graph above plus some Add operations at the end that sum the losses.

So in the second case you will use much more memory: over the 64 iterations, you build a single graph that just keeps growing, and thus you use more and more memory.

6 Likes

So we have to make sure that the batch size is not too large, or we will run out of memory.

1 Like

Here are three equivalent pieces of code, with different runtime/memory consumption.
Assume that you want to run SGD with a batch size of 100.
(I didn’t run the code below, so there might be some typos, sorry in advance.)

1: single batch of 100 (least runtime, more memory)

# some code
# Initialize dataset with batch size 100
for input, target in dataset:
    pred = net(input)
    loss = crit(pred, target)
    # one graph is created here
    opt.zero_grad()
    loss.backward()
    # graph is cleared here
    opt.step()

2: multiple small batches of 10 (more runtime, least memory)

# some code
# Initialize dataset with batch size 10
opt.zero_grad()
for i, (input, target) in enumerate(dataset):
    pred = net(input)
    loss = crit(pred, target)
    # one graph is created here
    loss.backward()
    # graph is cleared here
    if (i+1)%10 == 0:
        # every 10 iterations of batches of size 10
        opt.step()
        opt.zero_grad()

3: accumulate loss for multiple batches (more runtime, more memory)

# some code
# Initialize dataset with batch size 10
loss = 0
for i, (input, target) in enumerate(dataset):
    pred = net(input)
    current_loss = crit(pred, target)
    # current graph is appended to existing graph
    loss = loss + current_loss
    if (i+1)%10 == 0:
        # every 10 iterations of batches of size 10
        opt.zero_grad()
        loss.backward()
        # huge graph is cleared here
        opt.step()
        # reset the accumulated loss so the next 10 iterations build a fresh graph
        loss = 0

It should be clear that case 3 is not what you want.
The choice between cases 1 and 2 is a tradeoff between memory and speed, so it depends on what you want to do.
Note that if you can fit a batch size of 50 in memory, you can do a variation of case 2 with a batch size of 50 and an update every 2 iterations.
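For completeness, a sketch of that variation (same placeholder names as above, assuming the dataset is initialized with a batch size of 50):

# Initialize dataset with batch size 50
opt.zero_grad()
for i, (input, target) in enumerate(dataset):
    pred = net(input)
    loss = crit(pred, target)
    loss.backward()
    if (i+1)%2 == 0:
        # every 2 iterations of batches of size 50
        opt.step()
        opt.zero_grad()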

91 Likes

In my use case, I am doing image retrieval using a siamese network with 2 branches, so a dataset sample contains two images and a label indicating whether they are similar or not.

I do not want to change the image aspect ratio, so randomly cropping the images to the same size is not a valid choice. As a result, the batch size is actually 1. Each time we process one image pair, we accumulate the loss; when the number of processed pairs reaches the real batch size, we back-propagate the accumulated loss.

In case 2, each time a single loss is calculated, that loss (which should be divided by the real batch size) is immediately back-propagated and then the graph is freed, which is more memory efficient. I think the results of case 2 and case 3 should be the same. But in case 2, since we back-propagate many more times, training is a lot slower (I have done some tests and found that).

I would prefer case 3 for its faster training speed, but we need to choose the real batch size carefully in order not to blow up the memory.

1 Like

Follow-up: first I tried to accumulate 64 single losses and then do one backward, but without success (GPU out of memory). When I reduced the number of accumulated losses to 16, it worked. So right now the real batch size is 64, but I do a backward every 16 samples (4 backwards for the whole batch).

Thanks a lot… I can understand it clearly now

Can you explain why #3 uses more memory than #2?
Why does calling loss.backward less often cause it to use more memory?

#3 uses more memory because you need to store the intermediate results of 10 forward passes to be able to do the backpropagation. In #2 you never have more than the intermediate results of a single forward pass.

6 Likes

That makes sense.

Also, you wrote

# current graph is appended to existing graph
loss = loss + current_loss

I thought the loss would just be a scalar? But is it actually the entire graph?

loss here is a Variable containing a single element, but it also carries all the history of the computations that were made, so that it can backpropagate through them.
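One way to see that history is to look at the grad_fn attribute (a small sketch; the exact name printed can differ between versions):

import torch
from torch.autograd import Variable

x = Variable(torch.ones(3), requires_grad=True)
loss = (x * 2).sum()
print(loss)          # a Variable holding a single element
print(loss.grad_fn)  # the root of the graph that loss.backward() will traverse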

2 Likes