Gradients are zero, unsure why

Hello,

I am converting a simple CNN that I had working in Keras to PyTorch. However, the gradients are zero when training in PyTorch. I have checked the data importing, model building, and training portions, but can’t figure out what is causing this. I suspect it is a simple error in the training portion of the script or in how I am setting everything up for Autograd.

Below are the script and output from training.

Thanks for your time!

Note: the input images I am working with are of size 2 x 128
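Assuming each sample reaches the model as a tensor of shape (1, 2, 128), i.e. one channel holding the 2 x 128 image (the post does not show how the dataset is built), the layer shapes can be traced on a dummy batch as a quick sanity check that `dense1` really expects 80*130 inputs:

```python
import torch
from torch import nn

# Assumption: each sample is a (1, 2, 128) tensor (channels, rows, columns).
x = torch.zeros(4, 1, 2, 128)  # dummy batch of 4 samples

zero_pad = nn.ZeroPad2d((1, 1, 0, 0))  # one zero column on the left and right
conv1 = nn.Conv2d(1, 256, kernel_size=(1, 3), stride=1, padding=(0, 1))
conv2 = nn.Conv2d(256, 80, kernel_size=(2, 3), stride=1, padding=(0, 1))

with torch.no_grad():
    h0 = zero_pad(x)      # width 128 -> 130
    h1 = conv1(h0)        # (1,3) kernel with width padding 1 keeps 2 x 130
    h2 = conv2(h1)        # (2,3) kernel collapses the 2 rows into 1
    flat = h2.flatten(1)  # one feature vector per sample

print(h0.shape, h1.shape, h2.shape, flat.shape)
```

The flattened width is 80 * 1 * 130 = 10400, which matches `nn.Linear(80*130, 256)`, so the layer sizes themselves are consistent.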

Script:

import numpy as np
import time
from torch.utils.data import DataLoader, random_split
import torch
from torch import nn
from torch import optim
import torch.nn.functional as F

# Import data here

BATCH_SIZE = 1024
SPLIT = .5

# Split into train/val sets
total = len(dataset)
lengths = [int(len(dataset)*SPLIT)]
lengths.append(total - lengths[0])
print("Splitting into {} train and {} val".format(lengths[0], lengths[1]))
train_set, val_set = random_split(dataset, lengths)

# Setup dataloaders
train_dataloader = DataLoader(train_set, batch_size=BATCH_SIZE)
val_dataloader = DataLoader(val_set, batch_size=BATCH_SIZE)

#Setup the model as a class
class CNN(nn.Module):   
    def __init__(self):
        super(CNN, self).__init__()

        # define the dropout layer
        self.dropout  = nn.Dropout(p = 0.5)
        
        # add a zero to the left and right hand sides of the input
        self.zero_pad = nn.ZeroPad2d((1,1,0,0))
        
        # convolutional layers w/ weight initialization
        self.conv1 = nn.Conv2d(1, 256, kernel_size=(1,3), stride=1, padding = (0,1), bias = True)
        self.conv2 = nn.Conv2d(256, 80, kernel_size=(2,3), stride=1, padding = (0,1), bias = True)
        
        # dense layers w/ weight initialization
        self.dense1 = nn.Linear(80*130, 256, bias =True)
        self.dense2 = nn.Linear(256,11, bias = True)

        
    # Defining the forward pass    
    def forward(self, x):
        x = self.zero_pad(x)
        x = F.relu(self.conv1(x))
        x = self.dropout(x)
        x = F.relu(self.conv2(x))
        x = self.dropout(x)
        x = x.view(x.size(0), -1)
        x = F.relu(self.dense1(x))
        x = self.dense2(x)
        return x
    
model = CNN()

criterion = nn.CrossEntropyLoss()
# specify parameters in optimizer to match Keras
optimizer = optim.Adam(params = model.parameters())
# checking if GPU is available (use_cuda must be defined even on CPU-only machines)
use_cuda = torch.cuda.is_available()
if use_cuda:
    print('CUDA is available')
device = torch.device("cuda:0" if use_cuda else "cpu")
torch.backends.cudnn.benchmark = False
model = model.to(device)
criterion = criterion.to(device)

# patience for early stopping
patience = 5

epochs = 100
train_losses = []
valid_losses = []
val_best = np.inf
best_ep = 0
# Counter for early stopping
patience_counter = 0

start_all = time.time()
for e in range(epochs):
    start_ep = time.time()
    running_loss = 0
    rl = 0
    model.train()
    for data, labels in train_dataloader:
        data = data.to(device)
        labels = labels.to(device)
        optimizer.zero_grad()
        # predictions on the training data
        output = model(data)
        loss = criterion(output, labels)
        loss.backward()
        optimizer.step()
        running_loss += loss.item()
    
    # switch dropout off once before validating, not on every batch
    model.eval()
    with torch.no_grad():
        for dv, lv in val_dataloader:
            dv = dv.to(device)
            lv = lv.to(device)
            op = model(dv)
            ll = criterion(op, lv)
            rl += ll.item()
    
    train_loss = running_loss/len(train_dataloader)
    val_loss = rl/len(val_dataloader)
    train_losses.append(train_loss)
    valid_losses.append(val_loss)

    if val_loss < val_best:
        val_best = val_loss
        model_name = 'best_model_test.pt'
        torch.save(model, model_name)
        best_ep = e
        patience_counter = 0
    else:  # early stopping
        patience_counter += 1

    end_ep = time.time()
    print('Epoch: ' + str(e))
    print(' - {:.3f}s - train_loss: {:.4f} - val_loss: {:.4f}'.format(
        end_ep - start_ep, train_loss, val_loss))
    # stop after `patience` consecutive epochs without improvement
    if patience_counter >= patience:
        break
end_all = time.time()
print('Total training time = ' + str(round((end_all-start_all)/60,3)) + ' minutes')

Output: (training log screenshot "torch_train", not reproduced)

What makes you say your gradients are zero?

It looks to me like your model is training, since train_loss goes down ever so slightly. Try increasing your learning rate and see if anything changes. You can also iterate over your model's parameters and check whether their gradients are actually zero.
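A minimal sketch of that check: after `loss.backward()`, each trainable parameter's `.grad` holds its gradient, so printing the norm per parameter shows whether anything is really zero. The small model and random batch below are hypothetical stand-ins for the real `CNN` and data.

```python
import torch
from torch import nn

torch.manual_seed(0)

# Hypothetical stand-in model and batch; swap in CNN() and a real batch.
model = nn.Sequential(nn.Flatten(), nn.Linear(2 * 128, 32), nn.ReLU(), nn.Linear(32, 11))
data = torch.randn(8, 1, 2, 128)
labels = torch.randint(0, 11, (8,))  # class indices, as nn.CrossEntropyLoss expects

criterion = nn.CrossEntropyLoss()
model.zero_grad()
loss = criterion(model(data), labels)
loss.backward()

# Record the L2 norm of every parameter's gradient.
grad_norms = {}
for name, p in model.named_parameters():
    if p.requires_grad:
        # p.grad stays None if the parameter never took part in the backward pass
        grad_norms[name] = None if p.grad is None else p.grad.norm().item()
        print(f"{name:12s} grad norm = {grad_norms[name]}")
```

Running the same loop inside the real training loop, right after `loss.backward()`, shows which layers (if any) receive vanishing gradients.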

@ayalaa2 thanks for your time.

The gradients are nearly zero, but that is not quite the point. The learning rate and all other parameters are exactly the same as in the Keras model I am referring to. Needless to say, the gradients in Keras were much larger.

I have gotten this model to replicate the performance I had in Keras, but I made a few changes to the code and haven’t been able to since. The script I have posted is nearly the exact code I had at that point in time, and I can’t seem to troubleshoot it.

I tagged this as autograd in case the issue had to do with how I am setting up the optimizer, as I am new to PyTorch.

Hi,

This may happen due to differences in default values between the two frameworks. Here are some differences I can think of, although I do not know Keras in detail.

  1. PyTorch's DataLoader does not shuffle the data by default, but your task seems to allow shuffling, so passing shuffle=True could help a lot.
  2. I do not know how you constructed the dataset, but normalization might be needed if it isn’t applied already.
  3. This might not be as important, but PyTorch initializes linear and conv layer weights from a uniform distribution by default. You might want to make sure this is also correct for your case.
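For the first two points, a sketch using `shuffle=True` and per-row standardization. The `TensorDataset` of random tensors is a hypothetical stand-in for the real dataset, and standardizing with statistics computed over the data is one common choice, not necessarily what the original Keras pipeline did.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

torch.manual_seed(0)

# Hypothetical stand-in data: 1000 samples of shape (1, 2, 128) with 11 classes.
x = torch.randn(1000, 1, 2, 128) * 5.0 + 3.0
y = torch.randint(0, 11, (1000,))

# Standardize each of the two rows using statistics over all samples and columns.
mean = x.mean(dim=(0, 3), keepdim=True)  # shape (1, 1, 2, 1)
std = x.std(dim=(0, 3), keepdim=True)
x_norm = (x - mean) / std

dataset = TensorDataset(x_norm, y)
# shuffle=True reshuffles the samples at the start of every epoch (the default is False)
train_dataloader = DataLoader(dataset, batch_size=1024, shuffle=True)

batch_x, batch_y = next(iter(train_dataloader))
print(batch_x.shape, batch_y.shape)
```

In a real split, the mean and standard deviation would normally be computed on the training portion only and reused for validation.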

Again, I think you should check the default values of every object you initialize here, since the two frameworks may follow different conventions.
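One way to do that check is to print the defaults PyTorch actually fills in. The Keras values in the comments (Adam learning rate 0.001, epsilon 1e-7) are the tf.keras defaults; which Keras version produced the original model is an assumption, and the `nn.Linear` model is a hypothetical stand-in for `CNN()`.

```python
import torch
from torch import nn, optim

model = nn.Linear(10, 11)  # hypothetical stand-in for CNN()

optimizer = optim.Adam(params=model.parameters())
criterion = nn.CrossEntropyLoss()
dropout = nn.Dropout()

# Defaults PyTorch used because the script did not pass them explicitly.
print("Adam lr:", optimizer.defaults["lr"])      # tf.keras default is also 0.001
print("Adam betas:", optimizer.defaults["betas"])
print("Adam eps:", optimizer.defaults["eps"])    # tf.keras default epsilon is 1e-7
print("Dropout p:", dropout.p)
print("CE reduction:", criterion.reduction)      # loss is averaged over the batch

# If the Keras run used epsilon=1e-7, it can be matched explicitly:
optimizer_keras_like = optim.Adam(model.parameters(), lr=1e-3, betas=(0.9, 0.999), eps=1e-7)
```

Passing every hyperparameter explicitly on both sides removes any doubt about which defaults each framework applied.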

Bests

Thank you for your time @Nikronic.

These are things I originally thought were wrong, as I was not used to the DataLoader format, but my data is being loaded properly. The inputs and labels correspond correctly and are shuffled.

Regarding the weight initialization, you are correct that it needs to be changed, but changing it does not affect performance. I left that version of the model out of the original post to keep it simple. Please find it below:

class CNN(nn.Module):   
    def __init__(self):
        super(CNN, self).__init__()

        # define the dropout layer
        self.dropout  = nn.Dropout(p = 0.5)
        
        # add a zero to the left and right hand sides of the input
        self.zero_pad = nn.ZeroPad2d((1,1,0,0))
        
        # convolutional layers w/ weight initialization
        self.conv1 = nn.Conv2d(1, 256, kernel_size=(1,3), stride=1, padding = (0,1), bias = True)
        torch.nn.init.xavier_uniform_(self.conv1.weight)
        self.conv2 = nn.Conv2d(256, 80, kernel_size=(2,3), stride=1, padding = (0,1), bias = True)
        torch.nn.init.xavier_uniform_(self.conv2.weight)
        
        # dense layers w/ weight initialization
        self.dense1 = nn.Linear(80*130, 256, bias =True)
        torch.nn.init.kaiming_normal_(self.dense1.weight, nonlinearity='relu')
        self.dense2 = nn.Linear(256,11, bias = True)
        torch.nn.init.kaiming_normal_(self.dense2.weight, nonlinearity='sigmoid')
        
    # Defining the forward pass    
    def forward(self, x):
        x = self.zero_pad(x)
        x = F.relu(self.conv1(x))
        x = self.dropout(x)
        x = F.relu(self.conv2(x))
        x = self.dropout(x)
        x = x.view(x.size(0), -1)
        x = F.relu(self.dense1(x))
        x = self.dense2(x)
        return x
    
model = CNN()

I am inclined to think the error is in my model's forward pass or in the training script itself.

Thanks again for your time
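If the goal is to mirror the Keras defaults exactly, one option is a single `apply` pass over the model. This sketch assumes the Keras layers kept their default initializers (Glorot-uniform kernels, zero biases); PyTorch's default bias initialization is not zero, and the model above leaves the biases at the PyTorch default. The small `nn.Sequential` is a hypothetical stand-in for `CNN`.

```python
import torch
from torch import nn

def keras_style_init(module: nn.Module) -> None:
    """Glorot-uniform weights and zero biases for conv and linear layers
    (assumed to match the Keras defaults of the original model)."""
    if isinstance(module, (nn.Conv2d, nn.Linear)):
        nn.init.xavier_uniform_(module.weight)
        if module.bias is not None:
            nn.init.zeros_(module.bias)

# Hypothetical stand-in; for the post's model this would be CNN().apply(keras_style_init)
model = nn.Sequential(
    nn.Conv2d(1, 8, kernel_size=(1, 3), padding=(0, 1)),
    nn.ReLU(),
    nn.Flatten(),
    nn.Linear(8 * 2 * 128, 11),
)
model.apply(keras_style_init)  # visits every submodule recursively

print(model[0].bias.abs().sum().item(), model[3].bias.abs().sum().item())
```

Using `apply` keeps the initialization in one place instead of repeating an `init` call after each layer definition.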