Where should I place .zero_grad()?

Hi all,

I just want to be sure: where should I call .zero_grad()?
In the official MNIST example, optimizer.zero_grad() is called at the beginning of the training loop.

def train(args, model, device, train_loader, optimizer, epoch):
    model.train()
    for batch_idx, (data, target) in enumerate(train_loader):
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()
        output = model(data)
        loss = F.nll_loss(output, target)
        loss.backward()
        optimizer.step()
        if batch_idx % args.log_interval == 0:
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                epoch, batch_idx * len(data), len(train_loader.dataset),
                100. * batch_idx / len(train_loader), loss.item()))
            if args.dry_run:
                break

from here: https://github.com/pytorch/examples/blob/a74badde33f924c2ce5391141b86c40483150d5a/mnist/main.py#L37

Also, in the official tutorials, model.zero_grad() is called right before backward().

for t in range(500):
    # Forward pass: compute predicted y by passing x to the model. Module objects
    # override the __call__ operator so you can call them like functions. When
    # doing so you pass a Tensor of input data to the Module and it produces
    # a Tensor of output data.
    y_pred = model(x)

    # Compute and print loss. We pass Tensors containing the predicted and true
    # values of y, and the loss function returns a Tensor containing the
    # loss.
    loss = loss_fn(y_pred, y)
    if t % 100 == 99:
        print(t, loss.item())

    # Zero the gradients before running the backward pass.
    model.zero_grad()

    # Backward pass: compute gradient of the loss with respect to all the learnable
    # parameters of the model. Internally, the parameters of each Module are stored
    # in Tensors with requires_grad=True, so this call will compute gradients for
    # all learnable parameters in the model.
    loss.backward()

In my understanding, the first way makes sense, because we want gradients only for the current batch…
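A quick standalone check (a minimal sketch, not taken from either example above) seems to confirm that .grad accumulates across backward() calls unless it is zeroed:

import torch

w = torch.ones(1, requires_grad=True)

loss = (2 * w).sum()
loss.backward()
print(w.grad)    # tensor([2.])

loss = (2 * w).sum()
loss.backward()
print(w.grad)    # tensor([4.]) -> gradients of both passes were summed

w.grad.zero_()   # what optimizer.zero_grad() / model.zero_grad() do per parameter
loss = (2 * w).sum()
loss.backward()
print(w.grad)    # tensor([2.]) -> only the gradient of the current pass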


Both approaches are valid for the standard use case, i.e. if you do not want to accumulate gradients over multiple iterations.
You can thus call optimizer.zero_grad() anywhere in the loop except between the loss.backward() and optimizer.step() calls.
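For example (a minimal runnable sketch with a dummy model and data, not one of the snippets above), the only placement that breaks training is between loss.backward() and optimizer.step(), because step() consumes the gradients that backward() just wrote:

import torch
import torch.nn as nn

model = nn.Linear(4, 1)
loss_fn = nn.MSELoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
data, target = torch.randn(8, 4), torch.randn(8, 1)

for _ in range(3):
    optimizer.zero_grad()       # valid: before the forward pass
    loss = loss_fn(model(data), target)
    loss.backward()             # writes .grad for every parameter
    # optimizer.zero_grad()     # NOT here: step() would only see zeroed gradients
    optimizer.step()            # updates the parameters using .grad
    # optimizer.zero_grad()     # also valid: right after the update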
