Weights not updating using optimizer.step() or when manually updating the parameters

Hi everyone,
I am new to pytorch and I am writing a simple cat-dog classifier - my dataloader and trainer scripts are given below :
Dataloader :

 import torchvision.transforms as transforms
import matplotlib.pyplot as plt
import numpy as np
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim 
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd 
import random
import torch
import skimage

from torch.utils.data.dataset import Dataset

class CatDogDataset(Dataset):
    def __init__(self, csv_file, root_dir):
        self.df = pd.read_csv(csv_file)
        self.root_dir = root_dir
#        self.transform = transforms.Compose([transforms.ToTensor()])
        
    def __len__(self):
        return len(self.df)
    
    # can add random rotation later on
    
    def __getitem__(self,idx):
        img_path = self.df.iloc[idx,0]
        image = skimage.io.imread(img_path)
        image = skimage.transform.resize(image,(3,128,128)).astype(np.float32)
        image = (image-image.mean())/image.std()
        image = image.astype(np.float32)
        gt = np.asarray(self.df.iloc[idx,1]).astype(np.float32)
        gt = torch.tensor(gt)
        sample = {'image_data' : image,'gt' : gt}
        return sample

Trainer :

import numpy as np
import pandas as pd
import torch
import matplotlib.pyplot as plt
import torchvision
import csv 
from torch.autograd import Variable
import glob
import torch.optim as optim
import torch.nn as nn
import torch.nn.functional as F
from random import shuffle 
import glob
import os
from data_loader import CatDogDataset
from torch.utils.data import DataLoader
from tqdm import tqdm
from torchsummary import summary
import sys

class Net(nn.Module):
    def __init__(self):
        super(Net,self).__init__()
        self.conv1 = nn.Conv2d(3,base_filters,3, padding = 1) 
        self.pool1 = nn.MaxPool2d(2)
        
        self.conv2 = nn.Conv2d(base_filters,base_filters*2,3, padding = 1)
        self.pool2 = nn.MaxPool2d(2)
        
        self.conv3 = nn.Conv2d(base_filters*2,base_filters*4,3, padding = 1)
        self.pool3 = nn.MaxPool2d(2)
        
        self.conv4 = nn.Conv2d(base_filters*4,base_filters*8,3, padding = 1)
        self.pool4 = nn.MaxPool2d(2)
        
        self.fc1 = nn.Linear(base_filters*8*8*8,120)
        self.fc2 = nn.Linear(120,100)
        self.fc3 = nn.Linear(100,1)
            
        
    def forward(self,x):
        x  = self.pool1(F.relu(self.conv1(x)))
        x  = self.pool2(F.relu(self.conv2(x)))
        x  = self.pool3(F.relu(self.conv3(x)))
        x  = self.pool4(F.relu(self.conv4(x)))
 
        x = x.view(-1,base_filters*8*8*8)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        #x = F.softmax(x)
        return x

base_filters = 50
batch_size = 1
net = Net()

loss_fn = torch.nn.MSELoss()
dataset_train = CatDogDataset("/home/megh/work/catdog/CD_data.csv","/home/megh/work/catdog/")
train_loader = DataLoader(dataset_train, batch_size = batch_size, shuffle= True, num_workers=1)

#optimizer = optim.Adam(net.parameters(), lr = 0.1)
net.cuda()

summary(net, input_size=(3, 128, 128))

net  = net.train()
tmp = []
lr = 1
for batch_idx, (animal) in enumerate(train_loader):      
    image = animal['image_data']
    image = image.cuda()
    gt = animal['gt']
    gt = gt.cuda()
    image, gt = Variable(image, requires_grad = True), Variable(gt, requires_grad = True)
    #optimizer.zero_grad()
    output = net(image)
    loss =  loss_fn(output,gt)
    loss.backward()
    #optimizer.step()
    
    for p in net.parameters():
        print(type(p.grad))
        p.requires_grad_()
        print(type(p.grad))
        p = p - lr*(p.grad)
        p.requires_grad_()
        print(type(p.grad))
        #print("helo")
        p.grad.zero_()
    print(loss)
    tmp.append(net.parameters())

The error what I get is :

<class 'torch.Tensor'>
<class 'torch.Tensor'>
<class 'NoneType'>
^^ prints this before the error
 AttributeError: 'NoneType' object has no attribute 'zero_'

I am not understanding why the datatype of gradient is becoming NoneType.
I have verified that the weights are not getting updated by storing the network parameters after each iteration (by commenting the p.grad.zero_()) and verifying that they are the same after every iteration.
However, when I iterate through the network parameters by writing a small script in my spyder kernel, there seems to be no issue, and the dtype doesnt become none after I do some operations on the parameters.
Any help will be greatly appreciated
Thanks

Sorry the code isnt formatted properly

Could you try to update the parameters inplace:

p.sub_(lr* p.grad)

This would avoid creating a new tensor named p, which doesn’t have the .grad attribute yet (None by default before the first backward call).
Also, you don’t need to call p.requires_grad_() inside the loop.

Note that Variables are deprecated and you can just use tensors now.

PS: I’ve formatted your code. If you would like to post code snippets, you can wrap your code into three backticks ``` :wink:

Hi, Thanks so much for your reply. Yes, I tried inplace operation as well. But it gives the following error:
RuntimeError: a leaf Variable that requires grad has been used in an in-place operation.

So well, now I changed my code to this (Commented out the manual weight update and used the optimizer object instead):

import numpy as np
import pandas as pd
import torch
import matplotlib.pyplot as plt
import torchvision
import csv 
from torch.autograd import Variable
import glob
import torch.optim as optim
import torch.nn as nn
import torch.nn.functional as F
from random import shuffle 
import glob
import os
from data_loader import CatDogDataset
from torch.utils.data import DataLoader
from tqdm import tqdm
from torchsummary import summary
import sys

class Net(nn.Module):
    def __init__(self):
        super(Net,self).__init__()
        self.conv1 = nn.Conv2d(3,base_filters,3, padding = 1) 
        self.pool1 = nn.MaxPool2d(2)
        
        self.conv2 = nn.Conv2d(base_filters,base_filters*2,3, padding = 1)
        self.pool2 = nn.MaxPool2d(2)
        
        self.conv3 = nn.Conv2d(base_filters*2,base_filters*4,3, padding = 1)
        self.pool3 = nn.MaxPool2d(2)
        
        self.conv4 = nn.Conv2d(base_filters*4,base_filters*8,3, padding = 1)
        self.pool4 = nn.MaxPool2d(2)
        
        self.fc1 = nn.Linear(base_filters*8*8*8,120)
        self.fc2 = nn.Linear(120,100)
        self.fc3 = nn.Linear(100,1)
            
        
    def forward(self,x):
        x  = self.pool1(F.relu(self.conv1(x)))
        x  = self.pool2(F.relu(self.conv2(x)))
        x  = self.pool3(F.relu(self.conv3(x)))
        x  = self.pool4(F.relu(self.conv4(x)))
 
        x = x.view(-1,base_filters*8*8*8)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        #x = F.softmax(x)
        return x

base_filters = 50
batch_size = 1
net = Net()

loss_fn = torch.nn.MSELoss()
dataset_train = CatDogDataset("/home/megh/work/catdog/CD_data.csv","/home/megh/work/catdog/")
train_loader = DataLoader(dataset_train, batch_size = batch_size, shuffle= True, num_workers=1)

optimizer = optim.Adam(net.parameters(), lr = 0.1)
net.cuda()

summary(net, input_size=(3, 128, 128))

net.train()
tmp = []
lr = 1
for batch_idx, (animal) in enumerate(train_loader):      
    image = animal['image_data']
    image = image.cuda()
    gt = animal['gt']
    gt = gt.cuda()
    image, gt = Variable(image, requires_grad = True), Variable(gt, requires_grad = True)
    output = net(image)
    loss =  loss_fn(output,gt)
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()

# =============================================================================
#     for p in net.parameters():
#         p.sub_(lr* p.grad)
#         p.grad.zero_()
# =============================================================================
    print(loss)
    tmp.append(net.parameters())

I ran this for about 141 iterations i.e 141 images and ran the following small code snippet to check if the weights of the parameters are updating at all :

for p,q in zip(tmp[0],tmp[140]):
    print(p==q)

Here tmp is the list-element which stores the model parameters every iteration (given in code)
The output that it gives is a series of tensors with all elements being equal to 1
Please let me know why the optimizer isn’t updating the weights. (I’ve spent like 24 hours trying to debug this :thinking::thinking:)
Thanks a lot!

Could you wrap the code into a torch.no_grad() block:

with torch.no_grad():
    for p in net.parameters():
        p,sub_(...

Make sure to call .clone() on the parameters, if you want to store them for debugging.
Otherwise you’ll store the reference, which will not show any changes in the parameter’s value.

Hi,
I did what you said. Now my code (the parameter update part looks like):

 with torch.no_grad():
        for p in net.parameters():
            p.sub_(lr* p.grad)
            p.grad.zero_()
            p = p.clone()

Still shows that the parameters are equal :frowning_face:

for p,q in zip(tmp[0],tmp[20]):
    print((p.data.sum())==(q.data.sum()))
    
tensor(1, device='cuda:0', dtype=torch.uint8)
tensor(1, device='cuda:0', dtype=torch.uint8)
tensor(1, device='cuda:0', dtype=torch.uint8)
tensor(1, device='cuda:0', dtype=torch.uint8)
tensor(1, device='cuda:0', dtype=torch.uint8)
tensor(1, device='cuda:0', dtype=torch.uint8)
tensor(1, device='cuda:0', dtype=torch.uint8)
tensor(1, device='cuda:0', dtype=torch.uint8)
tensor(1, device='cuda:0', dtype=torch.uint8)
tensor(1, device='cuda:0', dtype=torch.uint8)
tensor(1, device='cuda:0', dtype=torch.uint8)
tensor(1, device='cuda:0', dtype=torch.uint8)
tensor(1, device='cuda:0', dtype=torch.uint8)
tensor(1, device='cuda:0', dtype=torch.uint8)

Okay, I think it is working now. I cloned and checked that the parameters are not the same. Thanks a lot.