Loss.backward(): Grad is None

I built an LSTM network and used nn.MSELoss(), but the gradient it returns is 0. I don’t know what makes it return 0 and would appreciate some help.

import torch
import torch.nn as nn

import numpy as np
import matplotlib.pyplot as plt
import csv

from torch.autograd import Variable

closeSet=[];

with open("data.csv",'r') as csvfile:
    reader=csv.reader(csvfile)

    n=-1;
    for row in reader :
        
        n=n+1;
        
        if(n == 0):
            continue;
        
        closeSet.append(float(row[6]));
closeSet=torch.Tensor(closeSet);

average=closeSet.mean();
print(average);

still=closeSet.std();
print(still);

closeSet=(closeSet-average)/still;
print(closeSet)

train_data=closeSet[:150];

class LSTM(nn.Module):
    def __init__(self,input_size,hidden_size,num_layer,output_size):
        super(LSTM, self).__init__()
        
        self.input_size=input_size;
        self.hidden_size=hidden_size;
        self.num_layer=num_layer;
        self.output_size=output_size;
        
        self.Wfh,self.Wfx,self.bf=self.init_Weight_bias_gate();
        self.Wih,self.Wix,self.bi=self.init_Weight_bias_gate();
        self.Woh,self.Wox,self.bo=self.init_Weight_bias_gate();
        self.Wch,self.Wcx,self.bc=self.init_Weight_bias_gate();
        self.Wy=nn.Parameter(torch.randn(self.hidden_size,self.output_size,requires_grad=True));
        self.by=nn.Parameter(torch.randn(self.output_size,requires_grad=True));
        
        self.hList=[];
        self.cList=[];

        self.times=0;
        
        self.f_=torch.zeros(1,self.hidden_size,requires_grad=True);
        self.i_=torch.zeros(1,self.hidden_size,requires_grad=True);
        self.o_=torch.zeros(1,self.hidden_size,requires_grad=True);
        self.ct_=torch.zeros(1,self.hidden_size,requires_grad=True);
        self.h_=torch.zeros(1,self.hidden_size,requires_grad=True);
        self.c_=torch.zeros(1,self.hidden_size,requires_grad=True);
        
        self.y_=torch.zeros(1,self.hidden_size,requires_grad=True);

        self.hList.append(self.h_);
        self.cList.append(self.c_);

    def init_Weight_bias_gate(self):
        return (nn.Parameter(torch.randn(self.hidden_size,self.hidden_size),requires_grad=True),
                nn.Parameter(torch.randn(self.input_size,self.hidden_size),requires_grad=True),
                nn.Parameter(torch.randn(self.hidden_size),requires_grad=True))
                
    def forward(self,x):
        self.times+=1;

        self.f_=torch.sigmoid(self.hList[-1] @ self.Wfh + x @ self.Wfx + self.bf);
        
        self.i_=torch.sigmoid(self.hList[-1] @ self.Wih + x @ self.Wix + self.bi);
        
        self.ct_=torch.tanh(self.hList[-1] @ self.Wch + x @ self.Wcx + self.bc);

        self.o_=torch.sigmoid(self.hList[-1] @ self.Woh + x @ self.Wox + self.bo);
        
        self.c_=self.f_ * self.cList[-1] + self.i_ * self.ct_;
        
        self.h_=self.o_ * torch.tanh(self.c_);
        self.y_=self.hList[-1] @ self.Wy + self.by;
        
        self.f_.requires_grad_(True);
        self.i_.requires_grad_(True);
        self.ct_.requires_grad_(True);
        self.o_.requires_grad_(True);
        self.c_.requires_grad_(True);
        self.h_.requires_grad_(True);
        
        self.y_.requires_grad_(True);
        
        self.cList.append(self.c_);
        self.hList.append(self.h_);
        
        return self.y_;
    
    def reset(self):
        self.times=0;
        
        self.hList=[torch.zeros(1,self.hidden_size,requires_grad=True),];
        self.cList=[torch.zeros(1,self.hidden_size,requires_grad=True),];

l=LSTM(150,1,1,150);

criterion = nn.MSELoss();

optimizer = torch.optim.Adam(l.parameters(), lr=0.01)
lossList=[];  # collect the loss of each iteration
for i in range(10):
    x=train_data;
    x=x.clone().detach().unsqueeze(0);
    x.requires_grad_(True)
    
    y_output=l.forward(x);
    
    y_true=train_data;
    y_true=y_true.clone().detach();
    y_true.requires_grad_(True)
    
    loss=criterion(l.forward(x), y_true);
    
    optimizer.zero_grad();
    
    loss.backward();
    optimizer.step();
    
    print("Wfh.grad:");
    print(l.Wfh.grad)
    
    l.reset();

    lossList.append(loss.item())

I cannot reproduce the issue and get a valid gradient for l.Wfh.grad:

l=LSTM(150,1,1,150);
criterion = nn.MSELoss();
optimizer = torch.optim.Adam(l.parameters(), lr=0.01)
x = torch.randn(10, 10, 150)
y_output=l.forward(x);

y_true = torch.randn_like(y_output)
loss=criterion(l.forward(x), y_true);

optimizer.zero_grad();
loss.backward();
optimizer.step();

print("Wfh.grad:");
print(l.Wfh.grad)
# tensor([[0.0018]])

However, your code has a few issues:

  • Variables are deprecated since PyTorch 0.4, so remove their usage and use tensors directly instead,
  • I don’t know why you would need to call .requires_grad_(True) on parameters inside the forward method,
  • your output seems to change in shape depending on the iteration, and I’m not sure whether that is intended,
  • don’t call the model.forward method, as it will skip any registered hooks; call the model directly instead: model(input) (a minimal sketch of the difference follows below).
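
A minimal sketch of the last point, assuming an arbitrary nn.Module (the module and hook here are just placeholders, not part of your model): calling the model invokes __call__, which runs registered hooks around forward, while model.forward(x) bypasses them.

import torch
import torch.nn as nn

model = nn.Linear(4, 2)

# Hypothetical forward pre-hook, only to make the difference visible.
def log_input(module, inputs):
    print("hook saw input of shape", inputs[0].shape)

model.register_forward_pre_hook(log_input)

x = torch.randn(3, 4)
y = model(x)          # runs the hook, then forward
y = model.forward(x)  # bypasses the hook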

Hello, thank you for your help.

I have deleted .requires_grad_(True) on the parameters inside the forward method and now call model(input). But except for the gradients of Wix, bi, Wy, and by, the other parameters’ gradients are still 0.

I don’t know what leads to this result, and I don’t understand your third suggestion. I printed the shape of y_ and it is torch.Size([1, 150]). Could you explain in more detail?

class LSTM(nn.Module):
    def __init__(self,input_size,hidden_size,num_layer,output_size):
        super(LSTM, self).__init__()

        self.input_size=input_size;
        self.hidden_size=hidden_size;
        self.num_layer=num_layer;
        self.output_size=output_size;
        
        self.Wfh,self.Wfx,self.bf=self.init_Weight_bias_gate();
        self.Wih,self.Wix,self.bi=self.init_Weight_bias_gate();
        self.Woh,self.Wox,self.bo=self.init_Weight_bias_gate();
        self.Wch,self.Wcx,self.bc=self.init_Weight_bias_gate();
        
        self.Wy=nn.Parameter(torch.randn(self.hidden_size,self.output_size));
        self.by=nn.Parameter(torch.randn(self.output_size));
        
        self.hList=[];
        self.cList=[];
        
        self.times=0;
        
        self.f_=torch.zeros(1,self.hidden_size,requires_grad=True);
        self.i_=torch.zeros(1,self.hidden_size,requires_grad=True);
        self.o_=torch.zeros(1,self.hidden_size,requires_grad=True);
        self.ct_=torch.zeros(1,self.hidden_size,requires_grad=True);
        self.h_=torch.zeros(1,self.hidden_size,requires_grad=True);
        self.c_=torch.zeros(1,self.hidden_size,requires_grad=True);
        
        self.y_=torch.zeros(1,self.hidden_size,requires_grad=True);
  
        self.hList.append(self.h_);
        self.cList.append(self.c_);
        
    def init_Weight_bias_gate(self):
        return (nn.Parameter(torch.randn(self.hidden_size,self.hidden_size)),
                nn.Parameter(torch.randn(self.input_size,self.hidden_size)),
                nn.Parameter(torch.randn(self.hidden_size)))
                
    def forward(self,x):
        self.times+=1;
        
        self.f_=torch.sigmoid(self.hList[-1] @ self.Wfh + x @ self.Wfx + self.bf);
        
        self.i_=torch.sigmoid(self.hList[-1] @ self.Wih + x @ self.Wix + self.bi);
        
        self.ct_=torch.tanh(self.hList[-1] @ self.Wch + x @ self.Wcx + self.bc);
        
        self.o_=torch.sigmoid(self.hList[-1] @ self.Woh + x @ self.Wox + self.bo);
        
        self.c_=self.f_ * self.cList[-1] + self.i_ * self.ct_;
        
        self.h_=self.o_ * torch.tanh(self.c_);
        
        self.y_=self.hList[-1] @ self.Wy + self.by;
        
        self.cList.append(self.c_);
        self.hList.append(self.h_);
        
        print("self.y_:");
        print(self.y_.shape);
        
        return self.y_;
    
    def reset(self):
        self.times=0;
        
        self.hList=[torch.zeros(1,self.hidden_size,requires_grad=True),];
        self.cList=[torch.zeros(1,self.hidden_size,requires_grad=True),];

l=LSTM(150,1,1,150);

criterion = nn.MSELoss();

optimizer = torch.optim.Adam(l.parameters(), lr=0.01)

lossList=[];  # collect the loss of each iteration
for i in range(10):
    x=train_data;
    x=x.clone().detach().unsqueeze(0);
    
    x.requires_grad_(True)   
    
    y_output=l.forward(x);
    
    y_true=train_data;
    y_true=y_true.clone().detach();
    
    y_true.requires_grad_(True)
   
    loss=criterion(l(x), y_true);
    
    optimizer.zero_grad();
    
    loss.backward();
    optimizer.step();
    
    for name, parms in l.named_parameters():
            print('-->name:', name, '-->grad_requirs:', parms.requires_grad, \
                 ' -->grad_value:', parms.grad, '-->value:', parms)

    l.reset();

    lossList.append(loss.item())

I still cannot reproduce the issue, so it would be great if you could post a minimal, executable code snippet that reproduces it.

In this code you can see that the output shape changes in the second iteration:

l=LSTM(150,1,1,150);
criterion = nn.MSELoss();
optimizer = torch.optim.Adam(l.parameters(), lr=0.01)

x = torch.randn(10, 10, 150)
y_output = l(x);
print(y_output.shape)
# torch.Size([1, 150])

y_output = l(x);
print(y_output.shape)
# torch.Size([10, 10, 150])

which seems at least uncommon.
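
A minimal sketch of where this shape change likely comes from, assuming the posted forward pass: y_ is computed from the previous hidden state hList[-1], and the gate equations broadcast that state against x @ Wfx, so after the first call the stored state takes on the input’s batch shape (Wfh and bf are omitted here since they don’t change any shapes).

import torch

hidden_size, output_size = 1, 150
h0 = torch.zeros(1, hidden_size)      # initial hidden state, shape (1, 1)
Wy = torch.randn(hidden_size, output_size)
Wfx = torch.randn(150, hidden_size)
x = torch.randn(10, 10, 150)

print((h0 @ Wy).shape)                # first call:  torch.Size([1, 150])

h1 = torch.sigmoid(h0 + x @ Wfx)      # gates broadcast h0 against x @ Wfx
print(h1.shape)                       # torch.Size([10, 10, 1])

print((h1 @ Wy).shape)                # second call: torch.Size([10, 10, 150])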

The rest of the code also shows valid gradients:

y_true = torch.randn_like(y_output)
loss=criterion(l(x), y_true);

optimizer.zero_grad();
loss.backward();
optimizer.step();

print("Wfh.grad:");
print(l.Wfh.grad)
# tensor([[0.0018]])

for name, parms in l.named_parameters():
        print('-->name:', name, '-->grad_requirs:', parms.requires_grad, \
             ' -->grad_value:', parms.grad.abs().sum(), '-->value:', parms.abs().sum())
# -->name: Wfh -->grad_requirs: True  -->grad_value: tensor(0.0001) -->value: tensor(0.8921, grad_fn=<SumBackward0>)
# -->name: Wfx -->grad_requirs: True  -->grad_value: tensor(0.0883) -->value: tensor(121.5771, grad_fn=<SumBackward0>)
# -->name: bf -->grad_requirs: True  -->grad_value: tensor(0.0015) -->value: tensor(1.7418, grad_fn=<SumBackward0>)
# -->name: Wih -->grad_requirs: True  -->grad_value: tensor(1.4800e-05) -->value: tensor(0.3881, grad_fn=<SumBackward0>)
# -->name: Wix -->grad_requirs: True  -->grad_value: tensor(0.5420) -->value: tensor(120.5190, grad_fn=<SumBackward0>)
# -->name: bi -->grad_requirs: True  -->grad_value: tensor(0.0131) -->value: tensor(0.3184, grad_fn=<SumBackward0>)
# -->name: Woh -->grad_requirs: True  -->grad_value: tensor(0.0020) -->value: tensor(0.9126, grad_fn=<SumBackward0>)
# -->name: Wox -->grad_requirs: True  -->grad_value: tensor(0.6173) -->value: tensor(136.6227, grad_fn=<SumBackward0>)
# -->name: bo -->grad_requirs: True  -->grad_value: tensor(0.0181) -->value: tensor(1.3139, grad_fn=<SumBackward0>)
# -->name: Wch -->grad_requirs: True  -->grad_value: tensor(0.0004) -->value: tensor(1.6613, grad_fn=<SumBackward0>)
# -->name: Wcx -->grad_requirs: True  -->grad_value: tensor(0.7321) -->value: tensor(134.0075, grad_fn=<SumBackward0>)
# -->name: bc -->grad_requirs: True  -->grad_value: tensor(0.0076) -->value: tensor(1.8784, grad_fn=<SumBackward0>)
# -->name: Wy -->grad_requirs: True  -->grad_value: tensor(0.2735) -->value: tensor(103.7549, grad_fn=<SumBackward0>)
# -->name: by -->grad_requirs: True  -->grad_value: tensor(1.5557) -->value: tensor(113.4924, grad_fn=<SumBackward0>)

Sorry for the late reply. I have found the error.
When I replaced the following code

def init_Weight_bias_gate(self):
        return (nn.Parameter(torch.randn(self.hidden_size,self.hidden_size),requires_grad=True),
                nn.Parameter(torch.randn(self.input_size,self.hidden_size),requires_grad=True),
                nn.Parameter(torch.randn(self.hidden_size),requires_grad=True))

with the following code

def init_Weight_bias_gate(self):
        return (nn.Parameter(torch.randn(self.hidden_size,self.hidden_size) * 0.01),
                nn.Parameter(torch.randn(self.input_size,self.hidden_size) * 0.01),
                nn.Parameter(torch.randn(self.hidden_size) * 0.01))

the problem was solved.
It looks like the issue was the random parameter initialization: drawing the weights from a standard normal likely saturates the sigmoid/tanh gates, so their gradients vanish.
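
For reference, a minimal sketch of an alternative initialization with the same parameter shapes (a suggestion, not the thread’s original fix): scaling each weight by 1/sqrt(fan_in) instead of a fixed 0.01 also keeps the gate pre-activations small.

import math
import torch
import torch.nn as nn

def init_Weight_bias_gate(self):
    # Scale each weight by 1/sqrt(fan_in) so the sigmoid/tanh gates start
    # away from saturation; biases start at zero.
    Wh = nn.Parameter(torch.randn(self.hidden_size, self.hidden_size) / math.sqrt(self.hidden_size))
    Wx = nn.Parameter(torch.randn(self.input_size, self.hidden_size) / math.sqrt(self.input_size))
    b = nn.Parameter(torch.zeros(self.hidden_size))
    return (Wh, Wx, b)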

Thank you for your help.