Why do we need to set the gradients manually to zero in PyTorch?

Here are three equivalent pieces of code with different runtime and memory consumption.
Assume that you want to run SGD with a batch size of 100.
(I didn't run the code below, so there might be some typos. Sorry in advance.)

1: single batch of 100 (least runtime, more memory)

# some code
# Initialize dataset with batch size 100
for input, target in dataset:
    pred = net(input)
    loss = crit(pred, target)
    # one graph is created here
    opt.zero_grad()
    loss.backward()
    # graph is cleared here
    opt.step()
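
To see why the opt.zero_grad() call matters, here is a minimal sketch. The one-parameter tensor `w` and input `x` are toy values made up for illustration, not part of the code above. backward() adds to .grad instead of overwriting it, so without zeroing, gradients from earlier iterations carry over into the next step.

```python
import torch

# Toy "model" with a single weight: y = w * x (hypothetical values)
w = torch.ones(1, requires_grad=True)
x = torch.tensor([2.0])

(w * x).sum().backward()
first = w.grad.clone()    # d(w*x)/dw = x = 2

(w * x).sum().backward()  # no zeroing in between
second = w.grad.clone()   # the new gradient was added to the old one: 2 + 2

w.grad.zero_()            # reset the accumulated gradient, similar in effect to opt.zero_grad()
(w * x).sum().backward()
third = w.grad.clone()    # a fresh gradient again

print(first.item(), second.item(), third.item())
```

This accumulation is exactly what case 2 below relies on.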

2: multiple small batches of 10 (more runtime, least memory)

# some code
# Initialize dataset with batch size 10
opt.zero_grad()
for i, (input, target) in enumerate(dataset):
    pred = net(input)
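    # divide by 10 so the summed gradients match case 1
    # (assumes crit averages over the batch, the default for most losses)
    loss = crit(pred, target) / 10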
    # one graph is created here
    loss.backward()
    # graph is cleared here
    if (i+1)%10 == 0:
        # every 10 iterations of batches of size 10
        opt.step()
        opt.zero_grad()

3: accumulate loss for multiple batches (more runtime, more memory)

# some code
# Initialize dataset with batch size 10
loss = 0
for i, (input, target) in enumerate(dataset):
    pred = net(input)
    current_loss = crit(pred, target)
    # current graph is appended to existing graph
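    # divide by 10 so the total matches case 1 (assumes crit averages over the batch)
    loss = loss + current_loss / 10
    if (i+1)%10 == 0:
        # every 10 iterations of batches of size 10
        opt.zero_grad()
        loss.backward()
        # huge graph is cleared here
        opt.step()
        # start a fresh sum, the old graph has been freed
        loss = 0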

It should be clear that case 3 is not what you want.
The choice between cases 1 and 2 is a tradeoff between memory and speed, so it depends on what you want to do.
Note that if you can fit a batch size of 50 in memory, you can do a variation of case 2 with a batch size of 50 and update every 2 iterations.
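
A rough, self-contained sketch of that variation, checking that two batches of 50 give the same update as one batch of 100. The helper `train_accumulated`, its `accum_steps` parameter, the toy data, and the nn.Linear model are all made up for illustration. It assumes the criterion averages over the batch (the default for nn.MSELoss), which is why each loss is divided by `accum_steps`.

```python
import copy
import torch
from torch import nn

torch.manual_seed(0)

def train_accumulated(net, crit, opt, batches, accum_steps):
    """Run SGD, calling opt.step() once every `accum_steps` batches."""
    opt.zero_grad()
    for i, (inputs, target) in enumerate(batches):
        # Scale so the accumulated gradient equals the gradient of the
        # mean loss over all accum_steps batches (assumes mean reduction).
        loss = crit(net(inputs), target) / accum_steps
        loss.backward()  # adds into .grad, graph is freed afterwards
        if (i + 1) % accum_steps == 0:
            opt.step()
            opt.zero_grad()

# Toy regression data: 100 samples, 4 features (float64 to keep rounding noise small)
x = torch.randn(100, 4, dtype=torch.float64)
y = torch.randn(100, 1, dtype=torch.float64)

net_full = nn.Linear(4, 1).double()
net_accum = copy.deepcopy(net_full)  # identical starting weights
crit = nn.MSELoss()
opt_full = torch.optim.SGD(net_full.parameters(), lr=0.1)
opt_accum = torch.optim.SGD(net_accum.parameters(), lr=0.1)

# Case 1: one batch of 100, one update
train_accumulated(net_full, crit, opt_full, [(x, y)], accum_steps=1)

# Variation of case 2: two batches of 50, one update
halves = [(x[:50], y[:50]), (x[50:], y[50:])]
train_accumulated(net_accum, crit, opt_accum, halves, accum_steps=2)

same = all(
    torch.allclose(p, q)
    for p, q in zip(net_full.parameters(), net_accum.parameters())
)
print("same parameters after one update:", same)
```

Note that this equivalence only holds for layers that treat samples independently. Something like batch norm computes statistics per batch, so smaller batches would change the result.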
