I built my own LSTM module. The gradients of every parameter come out right except those of Wch, Wcx and bc, which are all exactly 0, and I can't see what makes them 0. Any help would be appreciated.
import torch
import torch.nn as nn

class LSTM(nn.Module):
    def __init__(self, input_size, hidden_size, num_layer, output_size):
        super(LSTM, self).__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.num_layer = num_layer
        self.output_size = output_size
        # one (hidden-to-hidden weight, input-to-hidden weight, bias) triple per gate
        self.Wfh, self.Wfx, self.bf = self.init_Weight_bias_gate()  # forget gate
        self.Wih, self.Wix, self.bi = self.init_Weight_bias_gate()  # input gate
        self.Woh, self.Wox, self.bo = self.init_Weight_bias_gate()  # output gate
        self.Wch, self.Wcx, self.bc = self.init_Weight_bias_gate()  # candidate cell state
        self.hList = []
        self.cList = []
        self.times = 0
        # placeholders; all of these get overwritten in forward()
        self.f_ = torch.zeros(1, self.hidden_size, requires_grad=True)
        self.i_ = torch.zeros(1, self.hidden_size, requires_grad=True)
        self.o_ = torch.zeros(1, self.hidden_size, requires_grad=True)
        self.ct_ = torch.zeros(1, self.hidden_size, requires_grad=True)
        self.h_ = torch.zeros(1, self.hidden_size, requires_grad=True)
        self.c_ = torch.zeros(1, self.hidden_size, requires_grad=True)
        self.hList.append(self.h_)  # h_0 = 0
        self.cList.append(self.c_)  # c_0 = 0
        self.predict = nn.Linear(hidden_size, output_size)

    def init_Weight_bias_gate(self):
        return (nn.Parameter(torch.randn(self.hidden_size, self.hidden_size)),
                nn.Parameter(torch.randn(self.input_size, self.hidden_size)),
                nn.Parameter(torch.randn(self.hidden_size)))

    def forward(self, x):
        self.times += 1
        self.f_ = torch.sigmoid(self.hList[-1] @ self.Wfh + x @ self.Wfx + self.bf)  # forget gate
        self.i_ = torch.sigmoid(self.hList[-1] @ self.Wih + x @ self.Wix + self.bi)  # input gate
        self.ct_ = torch.tanh(self.hList[-1] @ self.Wch + x @ self.Wcx + self.bc)    # candidate cell
        self.o_ = torch.sigmoid(self.hList[-1] @ self.Woh + x @ self.Wox + self.bo)  # output gate
        self.c_ = self.f_ * self.cList[-1] + self.i_ * self.ct_  # new cell state
        self.h_ = self.o_ * torch.tanh(self.c_)                  # new hidden state
        self.cList.append(self.c_)
        self.hList.append(self.h_)
        return self.predict(self.h_)

    def reset(self):
        # start a fresh sequence: zero the recurrent state
        self.times = 0
        self.hList = [torch.zeros(1, self.hidden_size, requires_grad=True)]
        self.cList = [torch.zeros(1, self.hidden_size, requires_grad=True)]
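For reference, these are the standard LSTM cell equations that forward() above is meant to implement, in the same row-vector layout as the h @ W and x @ W products in the code:

    f_t = \sigma(h_{t-1} W_{fh} + x_t W_{fx} + b_f)
    i_t = \sigma(h_{t-1} W_{ih} + x_t W_{ix} + b_i)
    \tilde{c}_t = \tanh(h_{t-1} W_{ch} + x_t W_{cx} + b_c)
    o_t = \sigma(h_{t-1} W_{oh} + x_t W_{ox} + b_o)
    c_t = f_t \odot c_{t-1} + i_t \odot \tilde{c}_t
    h_t = o_t \odot \tanh(c_t)

Note that W_{ch}, W_{cx} and b_c enter only through the candidate \tilde{c}_t, so every gradient that reaches them has to pass through that one tanh.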
The training loop:

l = LSTM(150, 1, 1, 150)  # input_size=150, hidden_size=1, num_layer=1, output_size=150
criterion = nn.MSELoss()
optimizer = torch.optim.Adam(l.parameters(), lr=0.1)
lossList = []
for i in range(10):
    # train_data is a 1-D tensor of length 150, defined elsewhere
    x = train_data.clone().detach().unsqueeze(0)          # inputs don't need requires_grad
    y_true = train_data.clone().detach().unsqueeze(0)     # (1, 150), matching the model output
    y_output = l(x)   # call the module once (not l.forward(x) and then l(x) a second time)
    loss = criterion(y_output.float(), y_true.float())
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    for name, parms in l.named_parameters():
        print('-->name:', name, '-->grad_requirs:', parms.requires_grad,
              ' -->grad_value:', parms.grad)
    l.reset()
    lossList.append(loss.item())
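One thing I noticed while staring at this: h_0 is all zeros and the weights come from plain torch.randn, so if the entries of train_data are O(1), the pre-activation x @ Wcx of the candidate cell has standard deviation around sqrt(150) ≈ 12. When tanh(z) rounds to exactly ±1.0 in float32, its local gradient 1 - tanh(z)^2 is exactly 0, which would kill every gradient flowing into Wch, Wcx and bc, while the sigmoid gates would keep tiny but nonzero gradients — which looks like the pattern in the output below. A small probe I can run after the loop to check this (my own sketch; it assumes l and x are still in scope):

# Diagnostic sketch, not part of the training code: is the candidate-cell
# tanh saturated? hList[0] is the zero initial hidden state after reset().
with torch.no_grad():
    z = l.hList[0] @ l.Wch + x @ l.Wcx + l.bc  # pre-activation of ct_ at the first step
    t = torch.tanh(z)
    print('z:', z)                      # |z| >> 1 means deep saturation
    print('tanh(z):', t)                # exactly 1. or -1. in float32 when saturated
    print('1 - tanh(z)**2:', 1 - t**2)  # exactly 0. would explain the zero gradients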
Output (long gradient tensors truncated with "..." for readability):
-->name: Wfh -->grad_requirs: True -->grad_value: tensor([[-1.2410e-05]])
-->name: Wfx -->grad_requirs: True -->grad_value: tensor([[-0.0009],
        [-0.0009],
        [-0.0008],
        ...,
        [-0.0003],
        [-0.0003]])
-->name: bf -->grad_requirs: True -->grad_value: tensor([-0.0003])
-->name: Wih -->grad_requirs: True -->grad_value: tensor([[-9.1791e-05]])
-->name: Wix -->grad_requirs: True -->grad_value: tensor([[-0.0079],
        [-0.0073],
        [-0.0070],
        ...,
        [-0.0025],
        [-0.0027]])
-->name: bi -->grad_requirs: True -->grad_value: tensor([-0.0024])
-->name: Woh -->grad_requirs: True -->grad_value: tensor([[-2.7305e-07]])
-->name: Wox -->grad_requirs: True -->grad_value: tensor([[-2.0400e-05],
        [-1.8909e-05],
        [-1.8090e-05],
        ...,
        [-6.5650e-06],
        [-7.0608e-06]])
-->name: bo -->grad_requirs: True -->grad_value: tensor([-6.3140e-06])
-->name: Wch -->grad_requirs: True -->grad_value: tensor([[0.]])
-->name: Wcx -->grad_requirs: True -->grad_value: tensor([[0.],
        [0.],
        [0.],
        ...,
        [0.],
        [0.]])
        (all 150 entries are exactly 0)
-->name: bc -->grad_requirs: True -->grad_value: tensor([0.])
-->name: predict.weight -->grad_requirs: True -->grad_value: tensor([[-0.0017],
        [-0.0024],
        [-0.0017],
        ...,
        [-0.0007],
        [-0.0003]])
-->name: predict.bias -->grad_requirs: True -->grad_value: tensor([-0.0359, -0.0492, -0.0352, -0.0250, -0.0477, -0.0271, -0.0369, -0.0240,
        ..., -0.0154, -0.0062])
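If the tanh really is saturating, I'd expect a smaller initialization to fix it. This is only a sketch of what I mean (the 1/sqrt(fan_in) scaling and the zero bias are my own choices, not anything from the code above):

# Hypothetical replacement for init_Weight_bias_gate: scale the random
# weights so x @ W stays O(1) for O(1) inputs instead of O(sqrt(input_size)).
def init_Weight_bias_gate(self):
    k = self.input_size ** -0.5
    return (nn.Parameter(torch.randn(self.hidden_size, self.hidden_size) * k),
            nn.Parameter(torch.randn(self.input_size, self.hidden_size) * k),
            nn.Parameter(torch.zeros(self.hidden_size)))

Is saturation actually what's zeroing Wch, Wcx and bc, or is something else going on?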