`zero_grad` before `step` causes gradient explosion?

I have this simplified code snippet, which loads an image and feeds it to a model with a single CNN layer.

import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader


class Model(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(3, 3, 3, padding='same').half()

    def forward(self, x):
        return self.conv(x)


def main(cfg):
    model = Model().cuda()
    dataset = cfg.dataset
    optimizer = optim.AdamW(model.parameters(), lr=cfg.learning_rate)

    train_dataloader = DataLoader(
        dataset, 
        batch_size=cfg.batch_size, 
        num_workers=cfg.num_workers,
        shuffle=False,
        pin_memory=True    
    )
    
    p = next(model.parameters())  # monitor one weight to see whether it updates
    for epoch in range(cfg.max_epochs):
        for image, _ in train_dataloader:

            print(p[0, 0, 0, 0])

            image = image.to('cuda').half()
            image = F.interpolate(image, (256, 256), mode='bilinear', align_corners=False)

            output = model(image)
            loss = F.mse_loss(output, image)

            loss.backward()
            optimizer.zero_grad()  # note: called before step() here (the order in question)
            optimizer.step()


class TestConfig:
    max_epochs = 10000
    root = "./temp/"
    batch_size = 1
    num_workers = 0
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ])
    dataset = torchvision.datasets.CIFAR10(
        root='/mnt/HDD3/khanh/temp/',
        train=True,
        download=True,
        transform=transform    
    )
    dtype = torch.float16
    device = 'cuda'
    learning_rate = 1


if __name__ == '__main__':
    cfg = TestConfig()
    main(cfg)

I noticed that when calling optimizer.step() and then optimizer.zero_grad(), the code works properly (the loss decreases and the model converges).

But when I call zero_grad() and then step(), p.grad is 0 after zero_grad() (which is expected), but p[0, 0, 0, 0] becomes nan after step().

Is this expected behaviour? To my understanding, calling zero_grad() before step() should have the effect of not updating the weights at all.
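
For reference, the two orderings side by side:

# working order: step() consumes the fresh gradients, which are then cleared
loss.backward()
optimizer.step()
optimizer.zero_grad()

# failing order: step() only ever sees zeroed gradients
loss.backward()
optimizer.zero_grad()
optimizer.step()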

Could you post a minimal and executable code snippet reproducing the issue?

Hi, I have edited the post to include the full, executable code snippet.

It seems that the problem comes from a combination of factors. The issue disappears in either of these 2 cases:

  1. The dtype is set to torch.float32 instead of torch.float16.
  2. AdamW is replaced with SGD, or its eps is set to a larger value (it currently works with eps=0.1); see the sketch below.
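
A minimal sketch of the second case, reusing the optimizer line from the snippet above. I believe the default eps=1e-8 is too small to be representable in float16 (its smallest positive subnormal is about 6e-8), while 0.1 is:

# case 2: a larger eps keeps AdamW's update denominator non-zero in float16
optimizer = optim.AdamW(model.parameters(), lr=cfg.learning_rate, eps=0.1)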

I cannot reproduce the issue using your code, but I also executed it for only 4 epochs instead of the specified 10k.

Manually using float16 can easily cause overflows and is thus not recommended. Use torch.cuda.amp instead for mixed-precision training.
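
A minimal sketch of what that looks like for the training loop above (assuming the model is kept in float32; autocast casts the forward ops and GradScaler scales the gradients):

scaler = torch.cuda.amp.GradScaler()

for image, _ in train_dataloader:
    image = image.to('cuda')           # inputs stay float32
    optimizer.zero_grad()
    with torch.cuda.amp.autocast():    # forward ops run in float16 where safe
        output = model(image)
        loss = F.mse_loss(output, image)
    scaler.scale(loss).backward()      # scale the loss to avoid gradient underflow
    scaler.step(optimizer)             # unscale grads; skip the step on inf/nan
    scaler.update()                    # adjust the scale factor for the next iteration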

Stateful optimizers (such as AdamW) will use their running stats to update parameters even if the gradient is zero, so you might want to check whether these stats are overflowing for some reason.
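
A rough sketch for inspecting those stats (AdamW stores them per parameter as exp_avg and exp_avg_sq in optimizer.state) after each step():

# check AdamW's per-parameter running stats for non-finite values
for group in optimizer.param_groups:
    for param in group['params']:
        state = optimizer.state[param]   # empty until the first step()
        if not state:
            continue
        for name in ('exp_avg', 'exp_avg_sq'):
            stat = state[name]
            print(name, stat.abs().max().item(),
                  'finite:', torch.isfinite(stat).all().item())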