Hello,
I’m new to PyTorch and I’m trying to implement an LSTM without using the nn module. However, autograd doesn’t seem to work as expected: the network doesn’t learn at all. The model is a single LSTM layer followed by a linear layer, and I only take an output at the final time step. My code is posted below; any help would be highly appreciated!
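For reference, these are the updates I’m trying to implement (the standard LSTM cell plus a single linear read-out on the last hidden state), where z_t is the concatenation of x_t and h_{t-1}:

f_t = \sigma(W_f z_t + b_f)
i_t = \sigma(W_i z_t + b_i)
\tilde{c}_t = \tanh(W_c z_t + b_c)
c_t = f_t \odot c_{t-1} + i_t \odot \tilde{c}_t
o_t = \sigma(W_o z_t + b_o)
h_t = o_t \odot \tanh(c_t)
\hat{y} = W h_T + b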
import numpy as np
import torch

dtype = torch.float
def initialize_weights(x, n_hidden):
    # z is the concatenation of the input x_t and the previous hidden state h_{t-1}
    z_size = n_hidden + x.shape[2]
    # gate weights/biases: forget (f), input (i), candidate cell (c), output (o)
    wf = (torch.randn(n_hidden, z_size, dtype=dtype) * 0.1).requires_grad_()
    bf = torch.zeros(n_hidden, dtype=dtype, requires_grad=True)
    wi = (torch.randn(n_hidden, z_size, dtype=dtype) * 0.1).requires_grad_()
    bi = torch.zeros(n_hidden, dtype=dtype, requires_grad=True)
    wc = (torch.randn(n_hidden, z_size, dtype=dtype) * 0.1).requires_grad_()
    bc = torch.zeros(n_hidden, dtype=dtype, requires_grad=True)
    wo = (torch.randn(n_hidden, z_size, dtype=dtype) * 0.1).requires_grad_()
    bo = torch.zeros(n_hidden, dtype=dtype, requires_grad=True)
    # output linear layer
    w = (torch.randn(1, n_hidden, dtype=dtype) * 0.1).requires_grad_()
    b = torch.zeros(1, dtype=dtype, requires_grad=True)
    return wf, bf, wi, bi, wc, bc, wo, bo, w, b
def forward_pass(x, wf, bf, wi, bi, wc, bc, wo, bo, w, b, n_hidden):
    t_max, n, _ = x.shape
    h = torch.zeros(t_max, n, n_hidden, dtype=dtype)     # hidden state for every time step
    c_prev = torch.zeros(n, n_hidden, dtype=dtype)       # cell state, starts at zero
    for t in range(t_max):
        if t == 0:
            z = torch.cat((x[0], h[t]), 1)                # h_{-1} = 0, so reuse the zero row
        else:
            z = torch.cat((x[t], h[t - 1]), 1)
        f = torch.sigmoid(z.mm(wf.t()) + bf)              # forget gate
        i = torch.sigmoid(z.mm(wi.t()) + bi)              # input gate
        c_tilda = torch.tanh(z.mm(wc.t()) + bc)           # candidate cell state
        c = f * c_prev + i * c_tilda
        o = torch.sigmoid(z.mm(wo.t()) + bo)              # output gate
        h[t] = o * torch.tanh(c)
        c_prev = c
    y_pred = h[t_max - 1].mm(w.t()) + b                   # output linear layer on the last hidden state
    return y_pred
def train_LSTM(x, y, n_hidden, learning_rate, epochs):
    error = []
    rmse = []
    wf, bf, wi, bi, wc, bc, wo, bo, w, b = initialize_weights(x, n_hidden)
    params = [wf, bf, wi, bi, wc, bc, wo, bo, w, b]
    for i in range(epochs):
        y_pred = forward_pass(x, wf, bf, wi, bi, wc, bc, wo, bo, w, b, n_hidden)
        loss = 0.5 * ((y_pred - y).pow(2).sum())
        error.append(loss.item())
        rmse.append(np.sqrt((error[i] * 2) / y.shape[0]))
        print("Epoch", i + 1)
        print("Training loss=", error[i])
        print("RMSE=", rmse[i])
        print("##########")
        loss.backward()
        # plain SGD update, then manually zero the gradients after updating the weights
        with torch.no_grad():
            for p in params:
                p -= learning_rate * p.grad
            for p in params:
                p.grad.zero_()
    return error
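
In case it helps with reproducing, here is a minimal way to run the code above on random toy data and to check whether gradients actually reach the weights after one backward pass. The shapes, target, and hyperparameters below are made up purely for illustration, not my real dataset:

torch.manual_seed(0)

# Toy data: x has shape (seq_len, batch, n_features); one target value per sequence.
t_max, n, n_features, n_hidden = 5, 32, 3, 16
x = torch.randn(t_max, n, n_features, dtype=dtype)
y = x.sum(dim=(0, 2)).unsqueeze(1)          # target of shape (batch, 1)

# Sanity check: after one backward pass the parameter gradients should typically be non-zero.
wf, bf, wi, bi, wc, bc, wo, bo, w, b = initialize_weights(x, n_hidden)
y_pred = forward_pass(x, wf, bf, wi, bi, wc, bc, wo, bo, w, b, n_hidden)
loss = 0.5 * (y_pred - y).pow(2).sum()
loss.backward()
print(wf.grad.abs().mean().item(), w.grad.abs().mean().item())

# Training run
error = train_LSTM(x, y, n_hidden, learning_rate=1e-2, epochs=50)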