I have this simplified code snippet, which loads an image and feeds it to a model with a single CNN layer:
```python
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader


class Model(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(3, 3, 3, padding='same').half()

    def forward(self, x):
        return self.conv(x)


def main(cfg):
    model = Model().cuda()
    dataset = cfg.dataset
    optimizer = optim.AdamW(model.parameters(), lr=cfg.learning_rate)
    train_dataloader = DataLoader(
        dataset,
        batch_size=cfg.batch_size,
        num_workers=cfg.num_workers,
        shuffle=False,
        pin_memory=True
    )
    p = next(model.parameters())
    for epoch in range(cfg.max_epochs):
        for image, _ in train_dataloader:
            print(p[0, 0, 0, 0])
            image = image.to('cuda').half()
            image = F.interpolate(image, (256, 256), mode='bilinear', align_corners=False)
            output = model(image)
            loss = F.mse_loss(output, image)
            loss.backward()
            # the ordering in question: zero_grad() before step()
            optimizer.zero_grad()
            optimizer.step()


class TestConfig:
    max_epochs = 10000
    root = "./temp/"
    batch_size = 1
    num_workers = 0
    transform = transforms.Compose(
        [transforms.ToTensor(),
         transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]
    )
    dataset = torchvision.datasets.CIFAR10(
        root='/mnt/HDD3/khanh/temp/',
        train=True,
        download=True,
        transform=transform
    )
    dtype = torch.float16
    device = 'cuda'
    learning_rate = 1


if __name__ == '__main__':
    cfg = TestConfig()
    main(cfg)
```
I noticed that when I call `optimizer.step()` and then `optimizer.zero_grad()`, the code works properly (the loss decreases and the model converges). But when I call `zero_grad()` and then `step()`, `p.grad` is 0 after `zero_grad()` (which is expected), yet `p[0, 0, 0, 0]` becomes `nan` after `step()`.

Is this expected behaviour? To my understanding, calling `zero_grad()` before `step()` should have the effect of not updating the weights at all.
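For reference, here is a minimal sketch that seems to reproduce the same `nan` without the CNN or the dataloader (an assumption on my part that a single fp16 parameter with a zeroed gradient isolates the same effect):

```python
import torch
import torch.optim as optim

# one half-precision parameter, like the conv weights above
p = torch.nn.Parameter(torch.ones(1, dtype=torch.float16, device='cuda'))
opt = optim.AdamW([p], lr=1)  # same lr as in TestConfig

# mimic what zero_grad() leaves behind in my run: a zero tensor, not None
p.grad = torch.zeros_like(p)
opt.step()
print(p)  # becomes nan here as well, matching the behaviour described above
```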