Why do we need to set the gradients manually to zero in pytorch?

@albanD, I have some doubts about what the computation graph looks like in the case where we accumulate the loss manually.

Assume there are 256 batches in total for the dataset. In the normal case (case 1), for each batch in an epoch a new computation graph is created, and after the backward pass the graph is freed. So 256 computation graphs are created and freed during one epoch.

In the other case (case 2), since we only do a backward pass every 64 batches, does that mean only 4 graphs are created? Each of these graphs is created by composing 64 of the smaller graphs from case 1, and the root node of the bigger graph is total_loss. The 64 smaller graphs all share the same set of learnable parameters. If that is the case, the bigger graph will consume a lot of memory, since it contains 64 copies of the small graph.

Is that right? Do you have any ideas?

1 Like

Hi,

Indeed, in the first case you will create 256 graphs, each working with one input.
In the second case, you will create only 4 graphs, but each of these 4 graphs is actually composed of 64 copies of the graph above plus some Add operations at the end that sum the losses.

Indeed, in the second case you will use much more memory: over the 64 iterations you will create a single graph that just keeps growing, and so you will use more and more memory.
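
To make that concrete, here is a minimal sketch (the model, batch shapes and sizes are made up, and it assumes a CUDA device so the allocated memory is easy to track) that watches memory grow while the losses are accumulated and only drop after the single backward:

import torch
import torch.nn as nn

device = "cuda"  # assumes a GPU is available
net = nn.Sequential(nn.Linear(512, 512), nn.ReLU(), nn.Linear(512, 1)).to(device)
crit = nn.MSELoss()

total_loss = 0
for i in range(64):
    x = torch.randn(32, 512, device=device)
    y = torch.randn(32, 1, device=device)
    total_loss = total_loss + crit(net(x), y)  # this iteration's graph stays alive
    if (i + 1) % 16 == 0:
        # keeps growing: every previous forward is still referenced by total_loss
        print(i + 1, torch.cuda.memory_allocated() // 1024, "KiB")

total_loss.backward()  # the whole accumulated graph is freed only here
print("after backward:", torch.cuda.memory_allocated() // 1024, "KiB")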

6 Likes

So we have to make sure the batch size is not too large, or we will run out of memory.

1 Like

Here are three equivalent pieces of code, with different runtime/memory consumption.
Assume that you want to run SGD with a batch size of 100.
(I didn't run the code below, so there might be some typos, sorry in advance.)

1: single batch of 100 (least runtime, more memory)

# some code
# Initialize dataset with batch size 100
for input, target in dataset:
    pred = net(input)
    loss = crit(pred, target)
    # one graph is created here
    opt.zero_grad()
    loss.backward()
    # graph is cleared here
    opt.step()

2: multiple small batches of 10 (more runtime, least memory)

# some code
# Initialize dataset with batch size 10
opt.zero_grad()
for i, (input, target) in enumerate(dataset):
    pred = net(input)
    loss = crit(pred, target)
    # one graph is created here
    loss.backward()
    # graph is cleared here
    if (i+1)%10 == 0:
        # every 10 iterations of batches of size 10
        opt.step()
        opt.zero_grad()

3: accumulate loss for multiple batches (more runtime, more memory)

# some code
# Initialize dataset with batch size 10
loss = 0
for i, (input, target) in enumerate(dataset):
    pred = net(input)
    current_loss = crit(pred, target)
    # current graph is appended to existing graph
    loss = loss + current_loss
    if (i+1)%10 == 0:
        # every 10 iterations of batches of size 10
        opt.zero_grad()
        loss.backward()
        # huge graph is cleared here
        opt.step()
        loss = 0  # reset the accumulated loss so the freed graph is not reused

It should be clear that case 3 is not what you want.
The choice between case 1 and case 2 is a tradeoff between memory and speed, so it depends on what you want to do.
Note that if you can fit a batch size of 50 in memory, you can do a variation of case 2 with a batch size of 50 and an update every 2 iterations (sketched below).
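
For completeness, here is a self-contained sketch of that variation; the model, data and learning rate are made up, but the pattern is batches of 50 with an optimizer step every 2 iterations, i.e. an effective batch size of 100:

import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset

net = nn.Linear(20, 1)
crit = nn.MSELoss()
opt = torch.optim.SGD(net.parameters(), lr=0.1)
# Initialize dataset with batch size 50
dataset = DataLoader(TensorDataset(torch.randn(1000, 20), torch.randn(1000, 1)),
                     batch_size=50)

opt.zero_grad()
for i, (input, target) in enumerate(dataset):
    pred = net(input)
    loss = crit(pred, target)
    loss.backward()           # graph is freed here, gradients accumulate in .grad
    if (i + 1) % 2 == 0:
        # every 2 batches of size 50 -> effective batch size of 100
        opt.step()
        opt.zero_grad()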

91 Likes

In my use case, I am doing image retrieval with a siamese network with 2 branches, so a dataset sample contains two images and a label indicating whether they are similar or not.

I do not want to change the image aspect ratio, so randomly cropping the images to the same size is not a valid choice. As a result, the batch size is actually 1. Each time we process one image pair and accumulate the loss, and when the number of image pairs reaches the real batch size, we back-propagate the accumulated loss.

In case 2, each time a single loss is calculated, the loss (which should be divided by the real batch size) is immediately back-propagated and the graph is then freed, which is more memory efficient. I think the results of case 2 and case 3 should be the same. But in case 2, since we back-propagate many more times, the training speed is a lot slower (I have done some tests to confirm that).

I would prefer case 3 for its faster training speed, but we need to choose the real batch size carefully in order not to blow up the memory.
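
For reference, here is a minimal sketch of the per-pair scheme described above; the network, loss and image sizes are made up, but it shows batch size 1, each pair's loss divided by the real batch size, backward called immediately, and an optimizer step once a full "real batch" of pairs has been processed:

import torch
import torch.nn as nn

class SiameseNet(nn.Module):
    def __init__(self):
        super().__init__()
        # shared branch that handles arbitrary image sizes via adaptive pooling
        self.branch = nn.Sequential(
            nn.Conv2d(1, 4, 3, padding=1), nn.ReLU(),
            nn.AdaptiveAvgPool2d(4), nn.Flatten(), nn.Linear(4 * 4 * 4, 32))

    def forward(self, a, b):
        return self.branch(a), self.branch(b)

siamese_net = SiameseNet()
pair_loss = nn.CosineEmbeddingLoss()
opt = torch.optim.SGD(siamese_net.parameters(), lr=0.01)
real_batch_size = 64

# fake image pairs of varying sizes, each "batch" being a single pair
pairs = [(torch.randn(1, 1, 40 + i, 30 + i), torch.randn(1, 1, 40 + i, 30 + i),
          torch.tensor([1.0 if i % 2 else -1.0])) for i in range(128)]

opt.zero_grad()
for i, (img_a, img_b, label) in enumerate(pairs):
    feat_a, feat_b = siamese_net(img_a, img_b)
    loss = pair_loss(feat_a, feat_b, label) / real_batch_size  # scale the per-pair loss
    loss.backward()                     # graph is freed right away, grads accumulate
    if (i + 1) % real_batch_size == 0:
        opt.step()
        opt.zero_grad()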

1 Like

Follow-up: first I tried to accumulate 64 single losses and then do one backward pass, but without success (GPU out of memory). When I reduced the number of accumulated losses to 16, it worked. So right now the real batch size is 64, but I do a backward pass every 16 samples (4 backward passes for the whole batch).

Thanks a lot… I understand it clearly now.

Can you explain why #3 uses more memory than #2?
Why does calling loss.backward less often cause it to use more memory?

#3 uses more memory because you need to store the intermediate results of 10 forward passes to be able to do the backpropagation. In #2 you never keep more than the intermediate results of 1 forward pass.
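
One rough way to see this is to count the tensors the autograd engine stashes for the backward pass; the model and data below are made up, and it assumes a PyTorch version that provides torch.autograd.graph.saved_tensors_hooks (1.10 or later):

import torch
import torch.nn as nn

net = nn.Sequential(nn.Linear(16, 32), nn.ReLU(), nn.Linear(32, 1))
crit = nn.MSELoss()
data = [(torch.randn(8, 16), torch.randn(8, 1)) for _ in range(10)]

saved = []

def pack(t):
    saved.append(t)  # autograd wants to keep t around for the backward pass
    return t

def unpack(t):
    return t

# case 2 style: backward every iteration, only one forward's history alive at a time
saved.clear()
with torch.autograd.graph.saved_tensors_hooks(pack, unpack):
    for x, y in data:
        crit(net(x), y).backward()
print("saved tensors per forward:", len(saved) // len(data))

# case 3 style: accumulate the loss, all 10 forwards' history alive until the one backward
saved.clear()
total = 0
with torch.autograd.graph.saved_tensors_hooks(pack, unpack):
    for x, y in data:
        total = total + crit(net(x), y)
print("saved tensors kept until backward:", len(saved))  # roughly 10x the number above
total.backward()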

6 Likes

That makes sense.

Also, you wrote

# current graph is appended to existing graph
loss = loss + current_loss

I thought the loss would just be a scalar? But is it actually the entire graph?

loss here is a Variable containing a single element, and it has associated with it all the history of the computations that were made, so that it can backpropagate.
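
Here is a tiny sketch of what that history looks like (the model below is made up; in current PyTorch, Variable and Tensor have been merged): the loss is a single-element tensor, and the recorded operations hang off its grad_fn.

import torch
import torch.nn as nn

net = nn.Linear(5, 1)
x = torch.ones(8, 5)
loss = (net(x) - 1).pow(2).mean()

print(loss.shape)                     # torch.Size([]) -> a single element
print(loss.grad_fn)                   # the last operation recorded in the history
print(loss.grad_fn.next_functions)    # the operations that feed into it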

2 Likes

Where is this history stored exactly? It seems like it’s stored outside the variable. Let’s say I create two loss functions like so:

B = 8
linear = nn.Linear(5, 1)
x = Variable(torch.ones(B, 5))
y = linear(x)
loss_1 = 10 - y.sum()
loss_2 = 5 - y.sum()

Now as soon as I backpropagate loss_1, buffers are cleared.

loss_1.backward()

Backpropagating on loss_2 will give an error now:

loss_2.backward()  #gives an error

RuntimeError: Trying to backward through the graph a second time, but the buffers have already been freed. Specify retain_graph=True when calling backward the first time.

So it seems like the history is stored outside both loss_1 and loss_2, and is not deleted after calling backward() if retain_graph is True.

EDIT: Is it correct to assume that a new graph is created at the step y = linear(x)? In that case, can it be presumed that those buffers (or that history) reside in y and are referred to by subsequent Variables like loss_1 and loss_2?

1 Like

@nivter I made a short video a while back that might shed some light on your questions: https://www.youtube.com/watch?v=4F2LfiY8JLo

I am new to PyTorch so I might have misunderstood. The three approaches don't look equivalent to me. I did a small test with the first two: basically fitting four 1s to 7, where the difference is that the first script changes the weights every step while the second does so every 5 steps.

import torch
import torch.nn as nn

torch.manual_seed(1)

model = nn.Sequential(
    nn.Linear(4, 8, bias=False),
    nn.ReLU(),
    nn.Linear(8, 1, bias=False),
)

x = torch.ones(4)
y0 = torch.tensor(7.)

loss_fn = nn.MSELoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)

optimizer.zero_grad()
for i in range(200):

    y = model(x)
    loss = loss_fn(y, y0)

    if loss.item() < 1e-5:
        print(f'after {i} steps')
        break

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

print('y:', y)
for label, value in model.state_dict().items():
    print(label)
    print(value)

after 126 steps
y: tensor([6.9988], grad_fn=)
0.weight
tensor([[ 0.8411, 0.3628, 0.4865, 0.8181],
[-0.4707, 0.2999, -0.1029, 0.2544],
[-0.0641, -0.1948, 0.0051, -0.1089],
[ 0.1826, -0.1949, -0.0365, -0.0450],
[ 0.6402, 0.5657, 1.0048, 0.7233],
[-0.1862, -0.3020, -0.0838, -0.2157],
[ 0.4406, 0.6248, 0.8989, 0.8726],
[-0.6859, 0.1128, -0.0575, 0.2771]])
2.weight
tensor([[ 0.8684, -0.3221, -0.2251, -0.1705, 0.9041, -0.0589, 0.7641, 0.0078]])

import torch
import torch.nn as nn

torch.manual_seed(1)

model = nn.Sequential(
    nn.Linear(4, 8, bias=False),
    nn.ReLU(),
    nn.Linear(8, 1, bias=False),
)

x = torch.ones(4)
y0 = torch.tensor(7.)

loss_fn = nn.MSELoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)

optimizer.zero_grad()
for i in range(5000):

    y = model(x)
    loss = loss_fn(y, y0)

    if loss.item() < 1e-5:
        print(f'after {i} steps')
        break

    loss.backward()

    if (i+1) % 5 == 0:
        optimizer.step()
        optimizer.zero_grad()

print('y:', y)
for label, value in model.state_dict().items():
    print(label)
    print(value)

after 630 steps
y: tensor([6.9988], grad_fn=)
0.weight
tensor([[ 0.8411, 0.3628, 0.4865, 0.8181],
[-0.4707, 0.2999, -0.1029, 0.2544],
[-0.0641, -0.1948, 0.0051, -0.1089],
[ 0.1826, -0.1949, -0.0365, -0.0450],
[ 0.6402, 0.5657, 1.0048, 0.7233],
[-0.1862, -0.3020, -0.0838, -0.2157],
[ 0.4406, 0.6248, 0.8989, 0.8726],
[-0.6859, 0.1128, -0.0575, 0.2771]])
2.weight
tensor([[ 0.8684, -0.3221, -0.2251, -0.1705, 0.9041, -0.0589, 0.7641, 0.0078]])

As you can see, the first program converges five times faster than the second one. The result would be similar if I changed the input to random numbers. It seems to me that with the second approach the weights are updated using averaged gradients, so the first approach is more efficient.

One more question: why does the third approach take more memory? Say,

loss_sum += loss

import sys
sys.getsizeof(loss)
sys.getsizeof(loss_sum)

I would get the same result for both (72 in my case). So where is the associated history stored? Is there a property or function to get it? Thanks

1 Like

Hi,

In your code, the first one effectively uses a batch size of 1 while the second one uses a batch size of 5.
In this particular case, where all the samples are the same, the batch size is useless, as the averaged gradients for the batch will be the same as the gradients for one sample, and so it's just faster to run with a batch size of 1.
Note that this is not true any more as soon as your inputs are actually different from each other, and working with batches allows you to get “better” gradients.

The last one uses more memory because it has to keep around all the history for all the elements in the batch, not just the last one.
This is saved by the autograd engine backend, and you cannot measure its size with sys.getsizeof.
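
A small sketch of that point (the model below is made up): sys.getsizeof only measures the Python wrapper object, while the tensor data lives in its own storage and the computation history is held by the autograd engine, reachable through grad_fn.

import sys
import torch
import torch.nn as nn

net = nn.Linear(1000, 1000)
x = torch.randn(64, 1000)

loss = net(x).sum()
loss_sum = loss + net(x).sum() + net(x).sum()

print(sys.getsizeof(loss), sys.getsizeof(loss_sum))  # same small number for both
print(loss.element_size() * loss.nelement())         # bytes of the scalar's own data: 4
print(loss_sum.grad_fn)                              # the history hangs off grad_fn instead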

1 Like

@albanD: Which option is the same as iter_size in Caffe, which is very popular in DeepLab? Thanks

You need to change the inner check from:

if (i+1)%10 == 0:

to

if (i+1)%iter_size == 0:

Thanks, but I meant: will option 1, option 2 or option 3 in your answer reproduce performance close to the iter_size option in Caffe?

All three options compute the exact same gradients, so they will give the same result as using Caffe with iter_size and a Caffe batch_size of batch_size / iter_size, in the same way that the batch size in examples 2 and 3 is reduced compared to example 1.
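
As a quick check with a made-up model, and assuming mean-reduction losses with each sub-batch loss divided by the number of sub-batches (which is roughly what Caffe's iter_size normalization amounts to), accumulated gradients match the gradient of one big batch:

import torch
import torch.nn as nn

torch.manual_seed(0)
net = nn.Linear(8, 1)
crit = nn.MSELoss()
x, y = torch.randn(100, 8), torch.randn(100, 1)

# one big batch of 100
net.zero_grad()
crit(net(x), y).backward()
big_batch_grad = net.weight.grad.clone()

# 10 sub-batches of 10 with accumulated gradients
net.zero_grad()
for xb, yb in zip(x.chunk(10), y.chunk(10)):
    (crit(net(xb), yb) / 10).backward()
accumulated_grad = net.weight.grad.clone()

print(torch.allclose(big_batch_grad, accumulated_grad, atol=1e-6))  # True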

2 Likes

Hello, as you described above:

Indeed, in the second case you will use much more memory: over the 64 iterations you will create a single graph that just keeps growing, and so you will use more and more memory.

Why does the size of the size-64 computation graph keep growing? Shouldn't it be constant, since the computation and the input size stay constant?

1 Like