Where should I place .zero_grad()?

Hi all,

I just want to be sure: where should I call .zero_grad()?
In the official MNIST example, .zero_grad() is called at the start of each training iteration, before the forward pass.

def train(args, model, device, train_loader, optimizer, epoch):
    model.train()
    for batch_idx, (data, target) in enumerate(train_loader):
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()
        output = model(data)
        loss = F.nll_loss(output, target)
        loss.backward()
        optimizer.step()
        if batch_idx % args.log_interval == 0:
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                epoch, batch_idx * len(data), len(train_loader.dataset),
                100. * batch_idx / len(train_loader), loss.item()))
            if args.dry_run:
                break

from here: https://github.com/pytorch/examples/blob/a74badde33f924c2ce5391141b86c40483150d5a/mnist/main.py#L37

In the official tutorials, however, zero_grad() is called right before backward().

for t in range(500):
    # Forward pass: compute predicted y by passing x to the model. Module objects
    # override the __call__ operator so you can call them like functions. When
    # doing so you pass a Tensor of input data to the Module and it produces
    # a Tensor of output data.
    y_pred = model(x)

    # Compute and print loss. We pass Tensors containing the predicted and true
    # values of y, and the loss function returns a Tensor containing the
    # loss.
    loss = loss_fn(y_pred, y)
    if t % 100 == 99:
        print(t, loss.item())

    # Zero the gradients before running the backward pass.
    model.zero_grad()

    # Backward pass: compute gradient of the loss with respect to all the learnable
    # parameters of the model. Internally, the parameters of each Module are stored
    # in Tensors with requires_grad=True, so this call will compute gradients for
    # all learnable parameters in the model.
    loss.backward()

As I understand it, the first way makes sense, because we want gradients only for the current batch…
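
The reason for wanting gradients from the current batch only is that backward() adds to whatever is already stored in .grad rather than overwriting it. A minimal sketch to check this, assuming a toy tensor w and a made-up loss (neither is taken from the examples above):

```python
import torch

# Hypothetical toy parameter, used only for illustration.
w = torch.tensor([1.0, 2.0], requires_grad=True)

# First backward pass: d(sum(3 * w))/dw is 3 for every element.
(3 * w).sum().backward()
first = w.grad.clone()

# Second backward pass without zeroing: the new gradient is added
# to the stored one instead of replacing it.
(3 * w).sum().backward()
accumulated = w.grad.clone()

# Reset the stored gradient, then run backward again: the result is
# the single-pass gradient once more.
w.grad = None
(3 * w).sum().backward()
fresh = w.grad.clone()

print(first, accumulated, fresh)
```

Here the gradient is reset by hand for brevity; in a training loop, optimizer.zero_grad() or model.zero_grad() does the same for all parameters.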

Both approaches are valid for the standard use case, i.e. if you do not want to accumulate gradients over multiple iterations.
You can thus call optimizer.zero_grad() anywhere in the loop except between loss.backward() and optimizer.step(), since that would clear the gradients before the optimizer can use them.
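
If you want to convince yourself, something like the following sketch compares the placements. The run() helper, the tiny nn.Linear model, the random data and the MSE loss are all assumptions made up for this comparison, not part of either official example.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


def run(placement, steps=3):
    """Train a tiny linear model; return (initial, final) weights.

    placement selects where zero_grad() is called in each iteration:
      "start"           -> before the forward pass (MNIST example style)
      "before_backward" -> right before loss.backward() (tutorial style)
      "after_backward"  -> between loss.backward() and optimizer.step()
    """
    torch.manual_seed(0)  # same initial weights and data for every run
    model = nn.Linear(4, 2)
    optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
    x = torch.randn(8, 4)
    y = torch.randn(8, 2)
    initial = model.weight.detach().clone()

    for _ in range(steps):
        if placement == "start":
            optimizer.zero_grad()
        loss = F.mse_loss(model(x), y)
        if placement == "before_backward":
            optimizer.zero_grad()
        loss.backward()
        if placement == "after_backward":
            optimizer.zero_grad()  # clears the gradients step() needs
        optimizer.step()

    return initial, model.weight.detach().clone()


init_a, w_a = run("start")
init_b, w_b = run("before_backward")
init_c, w_c = run("after_backward")

print("both valid placements agree:", torch.allclose(w_a, w_b))
print("valid placement updated the weights:", not torch.equal(init_a, w_a))
print("wrong placement left the weights unchanged:", torch.equal(init_c, w_c))
```

With plain SGD, clearing the gradients right before step() means the optimizer has nothing to apply, so the model simply does not train.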
